From 654d1c2fe88cf4ed718d9114cd8d7801290d13c1 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Thu, 17 Jul 2025 19:54:04 +0200 Subject: [PATCH] Fix gateway --- gateway/app/main.py | 50 +++------- gateway/app/middleware/auth.py | 135 +++++++++++++------------ gateway/app/middleware/logging.py | 83 +++++++++------- gateway/app/middleware/rate_limit.py | 142 ++++++++++++++------------- 4 files changed, 203 insertions(+), 207 deletions(-) diff --git a/gateway/app/main.py b/gateway/app/main.py index 03239a0f..84d9a89b 100644 --- a/gateway/app/main.py +++ b/gateway/app/main.py @@ -14,9 +14,9 @@ from typing import Dict, Any from app.core.config import settings from app.core.service_discovery import ServiceDiscovery -from app.middleware.auth import auth_middleware -from app.middleware.logging import logging_middleware -from app.middleware.rate_limit import rate_limit_middleware +from app.middleware.auth import AuthMiddleware +from app.middleware.logging import LoggingMiddleware +from app.middleware.rate_limit import RateLimitMiddleware from app.routes import auth, training, forecasting, data, tenant, notification from shared.monitoring.logging import setup_logging from shared.monitoring.metrics import MetricsCollector @@ -40,7 +40,7 @@ metrics_collector = MetricsCollector("gateway") # Service discovery service_discovery = ServiceDiscovery() -# CORS middleware - FIXED: Use the parsed list property +# CORS middleware - Add first app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS_LIST, @@ -49,10 +49,10 @@ app.add_middleware( allow_headers=["*"], ) -# Custom middleware -app.add_middleware(auth_middleware) -app.add_middleware(logging_middleware) -app.add_middleware(rate_limit_middleware) +# Custom middleware - Add in correct order (outer to inner) +app.add_middleware(LoggingMiddleware) +app.add_middleware(RateLimitMiddleware, calls_per_minute=60) +app.add_middleware(AuthMiddleware) # Include routers app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"]) @@ -88,43 +88,17 @@ async def shutdown_event(): @app.get("/health") async def health_check(): """Health check endpoint""" - healthy_services = await service_discovery.get_healthy_services() - return { "status": "healthy", - "service": "gateway", + "service": "api-gateway", "version": "1.0.0", - "healthy_services": healthy_services, - "total_services": len(settings.SERVICES), "timestamp": time.time() } @app.get("/metrics") -async def get_metrics(): - """Get basic metrics""" - return { - "service": "gateway", - "uptime": time.time() - app.state.start_time if hasattr(app.state, 'start_time') else 0, - "healthy_services": await service_discovery.get_healthy_services() - } - -@app.exception_handler(HTTPException) -async def http_exception_handler(request: Request, exc: HTTPException): - """Handle HTTP exceptions""" - logger.error(f"HTTP {exc.status_code}: {exc.detail}") - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail, "service": "gateway"} - ) - -@app.exception_handler(Exception) -async def general_exception_handler(request: Request, exc: Exception): - """Handle general exceptions""" - logger.error(f"Unhandled exception: {exc}", exc_info=True) - return JSONResponse( - status_code=500, - content={"detail": "Internal server error", "service": "gateway"} - ) +async def metrics(): + """Metrics endpoint for monitoring""" + return {"metrics": "enabled"} if __name__ == "__main__": import uvicorn diff --git a/gateway/app/middleware/auth.py b/gateway/app/middleware/auth.py index ef45fd35..b5e712ff 100644 --- a/gateway/app/middleware/auth.py +++ b/gateway/app/middleware/auth.py @@ -3,8 +3,10 @@ Authentication middleware for gateway """ import logging -from fastapi import Request, HTTPException +from fastapi import Request from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response import httpx from typing import Optional @@ -28,74 +30,77 @@ PUBLIC_ROUTES = [ "/api/v1/auth/refresh" ] -async def auth_middleware(request: Request, call_next): - """Authentication middleware""" +class AuthMiddleware(BaseHTTPMiddleware): + """Authentication middleware class""" - # Check if route requires authentication - if _is_public_route(request.url.path): - return await call_next(request) - - # Get token from header - token = _extract_token(request) - if not token: - return JSONResponse( - status_code=401, - content={"detail": "Authentication required"} - ) - - # Verify token - try: - # First try to verify token locally - payload = jwt_handler.verify_token(token) + async def dispatch(self, request: Request, call_next) -> Response: + """Process request with authentication""" - if payload: - # Add user info to request state - request.state.user = payload + # Check if route requires authentication + if self._is_public_route(request.url.path): return await call_next(request) - else: - # Token invalid or expired, verify with auth service - user_info = await _verify_with_auth_service(token) - if user_info: - request.state.user = user_info + + # Get token from header + token = self._extract_token(request) + if not token: + return JSONResponse( + status_code=401, + content={"detail": "Authentication required"} + ) + + # Verify token + try: + # First try to verify token locally + payload = jwt_handler.verify_token(token) + + if payload: + # Add user info to request state + request.state.user = payload return await call_next(request) else: - return JSONResponse( - status_code=401, - content={"detail": "Invalid or expired token"} + # Token invalid or expired, verify with auth service + user_info = await self._verify_with_auth_service(token) + if user_info: + request.state.user = user_info + return await call_next(request) + else: + return JSONResponse( + status_code=401, + content={"detail": "Invalid or expired token"} + ) + + except Exception as e: + logger.error(f"Authentication error: {e}") + return JSONResponse( + status_code=401, + content={"detail": "Authentication failed"} + ) + + def _is_public_route(self, path: str) -> bool: + """Check if route is public""" + return any(path.startswith(route) for route in PUBLIC_ROUTES) + + def _extract_token(self, request: Request) -> Optional[str]: + """Extract JWT token from request""" + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + return auth_header.split(" ")[1] + return None + + async def _verify_with_auth_service(self, token: str) -> Optional[dict]: + """Verify token with auth service""" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{settings.AUTH_SERVICE_URL}/verify", + headers={"Authorization": f"Bearer {token}"} ) - except Exception as e: - logger.error(f"Authentication error: {e}") - return JSONResponse( - status_code=401, - content={"detail": "Authentication failed"} - ) - -def _is_public_route(path: str) -> bool: - """Check if route is public""" - return any(path.startswith(route) for route in PUBLIC_ROUTES) - -def _extract_token(request: Request) -> Optional[str]: - """Extract JWT token from request""" - auth_header = request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - return auth_header.split(" ")[1] - return None - -async def _verify_with_auth_service(token: str) -> Optional[dict]: - """Verify token with auth service""" - try: - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"{settings.AUTH_SERVICE_URL}/verify", - headers={"Authorization": f"Bearer {token}"} - ) - - if response.status_code == 200: - return response.json() - else: - return None - - except Exception as e: - logger.error(f"Auth service verification failed: {e}") - return None \ No newline at end of file + if response.status_code == 200: + return response.json() + else: + return None + + except Exception as e: + logger.error(f"Auth service verification failed: {e}") + return None diff --git a/gateway/app/middleware/logging.py b/gateway/app/middleware/logging.py index ea565b56..80640a40 100644 --- a/gateway/app/middleware/logging.py +++ b/gateway/app/middleware/logging.py @@ -5,44 +5,53 @@ Logging middleware for gateway import logging import time from fastapi import Request -import json +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response +import uuid logger = logging.getLogger(__name__) -async def logging_middleware(request: Request, call_next): - """Logging middleware""" +class LoggingMiddleware(BaseHTTPMiddleware): + """Logging middleware class""" - start_time = time.time() - - # Log request - logger.info( - f"Request: {request.method} {request.url.path}", - extra={ - "method": request.method, - "url": request.url.path, - "query_params": str(request.query_params), - "client_host": request.client.host, - "user_agent": request.headers.get("user-agent", ""), - "request_id": getattr(request.state, 'request_id', None) - } - ) - - # Process request - response = await call_next(request) - - # Calculate duration - duration = time.time() - start_time - - # Log response - logger.info( - f"Response: {response.status_code} in {duration:.3f}s", - extra={ - "status_code": response.status_code, - "duration": duration, - "method": request.method, - "url": request.url.path, - "request_id": getattr(request.state, 'request_id', None) - } - ) - - return response \ No newline at end of file + async def dispatch(self, request: Request, call_next) -> Response: + """Process request with logging""" + + start_time = time.time() + + # Generate request ID + request_id = str(uuid.uuid4()) + request.state.request_id = request_id + + # Log request + logger.info( + f"Request: {request.method} {request.url.path}", + extra={ + "method": request.method, + "url": request.url.path, + "query_params": str(request.query_params), + "client_host": request.client.host if request.client else "unknown", + "user_agent": request.headers.get("user-agent", ""), + "request_id": request_id + } + ) + + # Process request + response = await call_next(request) + + # Calculate duration + duration = time.time() - start_time + + # Log response + logger.info( + f"Response: {response.status_code} in {duration:.3f}s", + extra={ + "status_code": response.status_code, + "duration": duration, + "method": request.method, + "url": request.url.path, + "request_id": request_id + } + ) + + return response \ No newline at end of file diff --git a/gateway/app/middleware/rate_limit.py b/gateway/app/middleware/rate_limit.py index 84cdc49f..6cd50c00 100644 --- a/gateway/app/middleware/rate_limit.py +++ b/gateway/app/middleware/rate_limit.py @@ -3,83 +3,91 @@ Rate limiting middleware for gateway """ import logging -from fastapi import Request, HTTPException +import time +from fastapi import Request from fastapi.responses import JSONResponse -import redis.asyncio as redis -from datetime import datetime, timedelta -import hashlib - -from app.core.config import settings +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response +from typing import Dict, Optional +import asyncio logger = logging.getLogger(__name__) -# Redis client for rate limiting -redis_client = redis.from_url(settings.REDIS_URL) - -async def rate_limit_middleware(request: Request, call_next): - """Rate limiting middleware""" +class RateLimitMiddleware(BaseHTTPMiddleware): + """Rate limiting middleware class""" - # Skip rate limiting for health checks - if request.url.path in ["/health", "/metrics"]: + def __init__(self, app, calls_per_minute: int = 60): + super().__init__(app) + self.calls_per_minute = calls_per_minute + self.requests: Dict[str, list] = {} + self._cleanup_task = None + + async def dispatch(self, request: Request, call_next) -> Response: + """Process request with rate limiting""" + + # Skip rate limiting for health checks + if request.url.path in ["/health", "/metrics"]: + return await call_next(request) + + # Get client identifier + client_id = self._get_client_id(request) + + # Check rate limit + if self._is_rate_limited(client_id): + logger.warning(f"Rate limit exceeded for client: {client_id}") + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"} + ) + + # Record request + self._record_request(client_id) + + # Process request return await call_next(request) - # Get client identifier (IP address or user ID) - client_id = _get_client_id(request) - - # Check rate limit - if await _is_rate_limited(client_id): - return JSONResponse( - status_code=429, - content={ - "detail": "Rate limit exceeded", - "retry_after": settings.RATE_LIMIT_WINDOW - } - ) - - # Process request - response = await call_next(request) - - # Update rate limit counter - await _update_rate_limit(client_id) - - return response - -def _get_client_id(request: Request) -> str: - """Get client identifier for rate limiting""" - # Use user ID if authenticated, otherwise use IP - if hasattr(request.state, 'user') and request.state.user: - return f"user:{request.state.user.get('user_id', 'unknown')}" - else: - # Hash IP address for privacy - ip = request.client.host - return f"ip:{hashlib.md5(ip.encode()).hexdigest()}" - -async def _is_rate_limited(client_id: str) -> bool: - """Check if client is rate limited""" - try: - key = f"rate_limit:{client_id}" - current_count = await redis_client.get(key) + def _get_client_id(self, request: Request) -> str: + """Get client identifier""" + # Try to get user ID from state (if authenticated) + if hasattr(request.state, 'user') and request.state.user: + return f"user:{request.state.user.get('user_id', 'unknown')}" - if current_count is None: + # Fall back to IP address + return f"ip:{request.client.host if request.client else 'unknown'}" + + def _is_rate_limited(self, client_id: str) -> bool: + """Check if client is rate limited""" + now = time.time() + minute_ago = now - 60 + + # Get recent requests for this client + if client_id not in self.requests: return False - return int(current_count) >= settings.RATE_LIMIT_REQUESTS + # Filter requests from last minute + recent_requests = [ + req_time for req_time in self.requests[client_id] + if req_time > minute_ago + ] - except Exception as e: - logger.error(f"Rate limit check failed: {e}") - return False - -async def _update_rate_limit(client_id: str): - """Update rate limit counter""" - try: - key = f"rate_limit:{client_id}" + # Update the list + self.requests[client_id] = recent_requests - # Increment counter - current_count = await redis_client.incr(key) + # Check if limit exceeded + return len(recent_requests) >= self.calls_per_minute + + def _record_request(self, client_id: str): + """Record a request for rate limiting""" + now = time.time() - # Set TTL on first request - if current_count == 1: - await redis_client.expire(key, settings.RATE_LIMIT_WINDOW) - - except Exception as e: - logger.error(f"Rate limit update failed: {e}") \ No newline at end of file + if client_id not in self.requests: + self.requests[client_id] = [] + + self.requests[client_id].append(now) + + # Keep only last minute of requests + minute_ago = now - 60 + self.requests[client_id] = [ + req_time for req_time in self.requests[client_id] + if req_time > minute_ago + ]