# ================================================================
# gateway/app/routes/auth.py
# ================================================================
"""
Authentication and User Management Routes for API Gateway
Unified proxy to auth microservice
"""

import logging
from typing import Dict, Any, Optional

import httpx
from fastapi import APIRouter, Request, Response, HTTPException, status
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector

logger = logging.getLogger(__name__)
router = APIRouter()

# Initialize service discovery and metrics
service_discovery = ServiceDiscovery()
metrics = MetricsCollector("gateway")

# Auth service configuration
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"

# Hop-by-hop headers (RFC 7230 section 6.1) describe a single connection and
# must not be forwarded by a proxy.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Request headers never sent upstream. httpx computes its own Content-Length
# for the body it actually sends.
_REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"content-length"}

# Response headers never copied back to the client:
# - server/date: upstream server details.
# - content-encoding/content-length: httpx has already decoded the body, so
#   these upstream values no longer describe what we return.
# - set-cookie: re-added one value at a time in forward_request.
_RESPONSE_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {
    "server",
    "date",
    "content-encoding",
    "content-length",
    "set-cookie",
}


class AuthProxy:
    """Authentication service proxy with enhanced error handling."""

    def __init__(self):
        # A single pooled client is shared by every proxied request.
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

    async def forward_request(
        self,
        method: str,
        path: str,
        request: Request,
    ) -> Response:
        """Forward a request to the auth service and relay its response.

        Args:
            method: HTTP method used for the upstream call.
            path: Upstream path without a leading slash, appended to the auth
                service base URL (e.g. ``"api/v1/auth/login"``).
            request: The incoming gateway request. Its headers, body and query
                string are forwarded.

        Returns:
            A response carrying the upstream status, body and filtered headers.
            For an incoming OPTIONS request, a locally built CORS preflight
            response is returned instead.

        Raises:
            HTTPException: 504 on upstream timeout, 503 when the auth service
                cannot be reached, 500 on any other forwarding error.
        """
        # Handle OPTIONS requests directly for CORS
        if request.method == "OPTIONS":
            return self._preflight_response(request)

        try:
            # Get auth service URL (with service discovery if available)
            auth_url = await self._get_auth_service_url()
            target_url = f"{auth_url}/{path}"

            # IMPORTANT: pass the live request so HeaderManager can include
            # headers added by middleware (it presumably also consults
            # request.state; see header_manager).
            headers = self._prepare_headers(request.headers, request)

            body = await request.body()

            logger.info("Forwarding %s /%s to auth service", method, path)
            response = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
                # multi_items() keeps repeated keys (?a=1&a=2); dict() would drop them.
                params=request.query_params.multi_items(),
            )

            # Record metrics
            metrics.increment_counter("gateway_auth_requests_total")
            metrics.increment_counter(
                "gateway_auth_responses_total",
                labels={"status_code": str(response.status_code)},
            )

            proxied = Response(
                content=response.content,
                status_code=response.status_code,
                headers=self._prepare_response_headers(dict(response.headers)),
                media_type=response.headers.get("content-type"),
            )
            # dict(response.headers) joins repeated headers with ", ", which
            # corrupts multiple cookies, so re-add each Set-Cookie separately.
            for cookie in response.headers.get_list("set-cookie"):
                proxied.headers.append("set-cookie", cookie)
            return proxied

        except httpx.TimeoutException:
            logger.error("Timeout forwarding %s /%s to auth service", method, path)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Authentication service timeout",
            )
        except httpx.ConnectError:
            logger.error("Connection error forwarding %s /%s to auth service", method, path)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Authentication service unavailable",
            )
        except Exception as e:
            # logger.exception keeps the traceback for unexpected failures.
            logger.exception("Error forwarding %s /%s to auth service: %s", method, path, e)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal gateway error",
            )

    def _preflight_response(self, request: Request) -> Response:
        """Build the 200 response for a CORS preflight (OPTIONS) request."""
        headers = {
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Max-Age": "86400",  # Cache preflight for 24 hours
        }
        allowed_origin = self._resolve_allowed_origin(request)
        if allowed_origin:
            headers["Access-Control-Allow-Origin"] = allowed_origin
        return Response(status_code=200, headers=headers)

    @staticmethod
    def _resolve_allowed_origin(request: Request) -> Optional[str]:
        """Return the Access-Control-Allow-Origin value, or None to omit it.

        The header must be a single string. A string setting is used as-is.
        For a collection of origins, the request's Origin is echoed back only
        when it appears in the collection (or the collection contains "*").
        """
        configured = settings.CORS_ORIGINS_LIST
        if not configured:
            return None
        if isinstance(configured, str):
            return configured
        origin = request.headers.get("origin")
        if origin and (origin in configured or "*" in configured):
            return origin
        return None

    async def _get_auth_service_url(self) -> str:
        """Get the auth service URL, preferring service discovery.

        Falls back to AUTH_SERVICE_URL when discovery returns nothing or raises.
        """
        try:
            # Try service discovery first
            service_url = await service_discovery.get_service_url("auth-service")
            if service_url:
                return service_url
        except Exception as e:
            logger.warning("Service discovery failed: %s", e)

        # Fall back to configured URL
        return AUTH_SERVICE_URL

    def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
        """Build the header dict sent to the auth service.

        Args:
            headers: Incoming headers. Only used when ``request`` is None.
            request: The incoming request. When given, headers come from the
                unified HeaderManager.

        Returns:
            The headers to forward, with hop-by-hop headers and Content-Length
            removed.
        """
        if request is not None:
            all_headers = header_manager.get_all_headers_for_proxy(request)
            logger.debug("Added headers from HeaderManager: %s", list(all_headers.keys()))
        else:
            all_headers = self._headers_to_dict(headers)

        forwarded = {
            key: value
            for key, value in all_headers.items()
            if key.lower() not in _REQUEST_HEADERS_TO_DROP
        }

        # DEBUG level: these values identify users/sessions and are logged per request.
        logger.debug(
            "📤 Forwarding headers - x_user_id: %s, x_is_demo: %s, x_demo_session_id: %s, headers: %s",
            forwarded.get("x-user-id", "MISSING"),
            forwarded.get("x-is-demo", "MISSING"),
            forwarded.get("x-demo-session-id", "MISSING"),
            list(forwarded.keys()),
        )
        return forwarded

    @staticmethod
    def _headers_to_dict(headers) -> Dict[str, str]:
        """Convert a headers object (or plain mapping) into a str -> str dict."""
        def _text(item):
            return item.decode() if isinstance(item, bytes) else item

        if hasattr(headers, "_list"):
            pairs = getattr(headers, "_list", [])
        elif hasattr(headers, "raw"):
            pairs = headers.raw
        else:
            # Headers is already a dict
            return dict(headers)
        return {_text(key): _text(value) for key, value in pairs}

    def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
        """Filter upstream response headers and add CORS headers.

        Drops the headers listed in _RESPONSE_HEADERS_TO_DROP (Set-Cookie is
        re-added separately by forward_request).
        """
        filtered_headers = {
            k: v for k, v in headers.items() if k.lower() not in _RESPONSE_HEADERS_TO_DROP
        }

        # Add CORS headers if needed.
        # NOTE(review): "*" cannot be combined with credentialed requests, yet
        # the preflight advertises Allow-Credentials: true. Confirm whether an
        # app-level CORS middleware overrides this before relying on it.
        if settings.CORS_ORIGINS:
            filtered_headers['Access-Control-Allow-Origin'] = '*'
            filtered_headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
            filtered_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'

        return filtered_headers


# Initialize proxy
auth_proxy = AuthProxy()
# ================================================================
# HEALTH CHECK for auth service
# ================================================================
# Registered BEFORE the catch-all route: Starlette matches routes in
# registration order, so "/{path:path}" would otherwise capture GET /health
# and proxy it upstream.
@router.get("/health")
async def auth_service_health():
    """Report whether the auth service answers its /health endpoint.

    Always returns a JSON body (never raises). "healthy" includes the upstream
    response time. "unhealthy" includes either the upstream status code or the
    error text.
    """
    try:
        # Resolve the URL the same way the proxy does (discovery, then config).
        auth_url = await auth_proxy._get_auth_service_url()
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{auth_url}/health")

        if response.status_code == 200:
            return {
                "status": "healthy",
                "auth_service": "available",
                "response_time_ms": response.elapsed.total_seconds() * 1000,
            }
        return {
            "status": "unhealthy",
            "auth_service": "error",
            "status_code": response.status_code,
        }
    except Exception as e:
        logger.error("Auth service health check failed: %s", e)
        return {
            "status": "unhealthy",
            "auth_service": "unavailable",
            "error": str(e),
        }


# ================================================================
# CATCH-ALL ROUTE for all auth and user endpoints
# ================================================================
# NOTE(review): OPTIONS is not in `methods`, so AuthProxy's OPTIONS branch is
# not reached through this route; preflight is presumably answered by
# app-level CORS middleware. Confirm.
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_auth_requests(path: str, request: Request):
    """Catch-all proxy: forward to the auth service under ``api/v1/auth/``."""
    return await auth_proxy.forward_request(request.method, f"api/v1/auth/{path}", request)


# ================================================================
# CLEANUP
# ================================================================
@router.on_event("shutdown")
async def cleanup():
    """Close the proxy's shared HTTP client when the app shuts down."""
    await auth_proxy.client.aclose()