# ================================================================
# gateway/app/routes/user.py
# ================================================================
"""
User management routes for the API Gateway.

Every route in this module is proxied to the auth service, which serves the
user endpoints under its ``/api/v1/auth/`` prefix.
"""

import logging
import httpx
from fastapi import APIRouter, Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from typing import Dict, Any, Optional
import json

from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector

logger = logging.getLogger(__name__)
router = APIRouter()

# Initialize service discovery and metrics
service_discovery = ServiceDiscovery()
metrics = MetricsCollector("gateway")

# Auth service configuration
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be forwarded by a proxy in either direction. "trailers" is the
# RFC 2616 spelling the previous list used; both spellings are kept.
HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
})

# Request headers httpx must compute for the outgoing request itself: "host"
# has to name the auth service rather than the gateway, and "content-length"
# has to match the body that is actually sent.
_UPSTREAM_REQUEST_DROP = HOP_BY_HOP_HEADERS | {"host", "content-length"}

# Response headers that no longer describe what the gateway returns. httpx
# transparently decompresses ``response.content``, so the upstream
# content-encoding / content-length would mislabel the body (Starlette computes
# a fresh content-length). "server" and "date" are dropped as before.
_DOWNSTREAM_RESPONSE_DROP = HOP_BY_HOP_HEADERS | {
    "server",
    "date",
    "content-encoding",
    "content-length",
}

# Identity headers the gateway's auth middleware may publish on
# request.state.injected_headers.
_IDENTITY_HEADERS = ("x-user-id", "x-user-email", "x-user-role")

CORS_ALLOW_METHODS = "GET, POST, PUT, DELETE, OPTIONS"


def _header_text(value) -> str:
    """Return *value* as ``str``; raw ASGI header bytes are decoded as latin-1."""
    if isinstance(value, (bytes, bytearray)):
        return value.decode("latin-1")
    return value


def _cors_allow_origin(request: Request) -> Optional[str]:
    """Return the Access-Control-Allow-Origin value for *request*, or None.

    A header value can carry exactly one origin, and browsers reject "*" on
    credentialed requests, so the caller's Origin is echoed back when it is in
    ``settings.CORS_ORIGINS_LIST`` (or that list contains "*"). Returns None
    when the request has no Origin or the origin is not allowed.
    """
    allowed = settings.CORS_ORIGINS_LIST or []
    if isinstance(allowed, str):
        # NOTE(review): tolerate a comma-separated string as well as a list.
        allowed = [item.strip() for item in allowed.split(",") if item.strip()]
    origin = request.headers.get("origin")
    if not origin:
        return None
    if "*" in allowed or origin in allowed:
        return origin
    return None


def _add_vary_origin(headers: Dict[str, str]) -> None:
    """Add "Origin" to *headers*' Vary value (in place) so caches key on it."""
    for key, value in headers.items():
        if key.lower() == "vary":
            if "origin" not in value.lower():
                headers[key] = f"{value}, Origin"
            return
    headers["Vary"] = "Origin"


class UserProxy:
    """Proxy that forwards user-management requests to the auth service."""

    def __init__(self):
        # One pooled client shared by every proxied request; see close().
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

    async def close(self) -> None:
        """Release the pooled connections (call from the app's shutdown hook)."""
        await self.client.aclose()

    async def forward_request(
        self,
        method: str,
        path: str,
        request: Request
    ) -> Response:
        """Forward *request* to ``{auth_url}/api/v1/auth/{path}`` and relay the reply.

        Args:
            method: HTTP method to use for the upstream call.
            path: Path below the auth service's ``/api/v1/auth/`` prefix.
            request: Incoming request; its headers, body and query string
                are forwarded.

        Returns:
            A Response mirroring the auth service's status, body and headers.
            An incoming OPTIONS request is answered locally as a CORS preflight.

        Raises:
            HTTPException: 504 on upstream timeout, 503 when the auth service
                cannot be reached, 500 for any other forwarding failure.
        """
        if request.method == "OPTIONS":
            return self._preflight_response(request)

        try:
            auth_url = await self._get_auth_service_url()
            # The auth service serves these endpoints under /api/v1/auth/,
            # not /api/v1/users/.
            target_url = f"{auth_url}/api/v1/auth/{path}"

            # request.headers is used (not a copy) so headers appended by
            # middleware are included; request.state is consulted as well.
            headers = self._prepare_headers(request.headers, request)
            body = await request.body()

            logger.info("Forwarding %s %s to auth service", method, path)
            response = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
                # multi_items() keeps repeated keys (?a=1&a=2) that dict() dropped.
                params=request.query_params.multi_items(),
            )

            metrics.increment_counter("gateway_auth_requests_total")
            metrics.increment_counter(
                "gateway_auth_responses_total",
                labels={"status_code": str(response.status_code)}
            )

            # Reading httpx headers as a mapping comma-joins repeated values,
            # which corrupts Set-Cookie; cookies are appended one by one below.
            upstream_headers = {
                key: value
                for key, value in response.headers.items()
                if key.lower() != "set-cookie"
            }
            response_headers = self._prepare_response_headers(upstream_headers, request)

            proxied = Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers,
                media_type=response.headers.get("content-type")
            )
            for cookie in response.headers.get_list("set-cookie"):
                proxied.headers.append("set-cookie", cookie)
            return proxied

        except httpx.TimeoutException as exc:
            logger.error("Timeout forwarding %s %s to auth service", method, path)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Authentication service timeout"
            ) from exc
        except httpx.ConnectError as exc:
            logger.error("Connection error forwarding %s %s to auth service", method, path)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Authentication service unavailable"
            ) from exc
        except Exception as exc:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error forwarding %s %s to auth service: %s", method, path, exc)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal gateway error"
            ) from exc

    def _preflight_response(self, request: Request) -> Response:
        """Answer a CORS preflight locally, without contacting the auth service."""
        headers = {
            "Access-Control-Allow-Methods": CORS_ALLOW_METHODS,
            "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Max-Age": "86400",  # Cache preflight for 24 hours
        }
        allow_origin = _cors_allow_origin(request)
        if allow_origin is not None:
            headers["Access-Control-Allow-Origin"] = allow_origin
            _add_vary_origin(headers)
        return Response(status_code=200, headers=headers)

    async def _get_auth_service_url(self) -> str:
        """Return the auth service base URL, preferring service discovery.

        A discovery error or empty result falls back to AUTH_SERVICE_URL, so a
        discovery outage does not block requests.
        """
        try:
            service_url = await service_discovery.get_service_url("auth-service")
        except Exception as exc:
            logger.warning("Service discovery failed: %s", exc)
        else:
            if service_url:
                return service_url
        return AUTH_SERVICE_URL

    def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
        """Build the headers to send to the auth service.

        Args:
            headers: Incoming headers, either a Starlette ``Headers`` object
                (whose ``raw`` list includes headers appended by middleware)
                or a plain ``str -> str`` mapping.
            request: Optional incoming request. ``request.state`` may carry
                ``injected_headers`` (a dict of identity headers) and
                ``is_demo_session``.

        Returns:
            Lower-cased header names mapped to values, without hop-by-hop,
            ``host`` or ``content-length`` headers, plus the gateway identifiers.
        """
        if hasattr(headers, "raw"):
            items = [(_header_text(k), _header_text(v)) for k, v in headers.raw]
        else:
            items = list(headers.items())

        # Names are lower-cased so later overrides replace, not duplicate, them.
        forwarded: Dict[str, str] = {}
        for name, value in items:
            key = name.lower()
            if key not in _UPSTREAM_REQUEST_DROP:
                forwarded[key] = value

        state = getattr(request, "state", None)
        injected = getattr(state, "injected_headers", None) or {}
        for key in _IDENTITY_HEADERS:
            if key in injected:
                # The middleware-published identity overrides anything the
                # client sent, so x-user-* cannot be spoofed by the caller.
                forwarded[key] = injected[key]
        # NOTE(review): a client-supplied x-is-demo header is still forwarded
        # unchanged when this flag is unset — confirm whether it should be stripped.
        if getattr(state, "is_demo_session", False):
            forwarded["x-is-demo"] = "true"

        # Add gateway identifier
        forwarded["x-forwarded-by"] = "bakery-gateway"
        forwarded["x-gateway-version"] = "1.0.0"
        return forwarded

    def _prepare_response_headers(
        self,
        headers: Dict[str, str],
        request: Optional[Request] = None
    ) -> Dict[str, str]:
        """Filter upstream response headers and apply the gateway's CORS headers.

        Args:
            headers: Headers from the auth service response.
            request: Incoming request, used to echo an allowed Origin. When it
                is omitted, the previous ``Access-Control-Allow-Origin: *`` is used.

        Returns:
            The headers to send to the client.
        """
        filtered = {
            key: value
            for key, value in headers.items()
            if key.lower() not in _DOWNSTREAM_RESPONSE_DROP
        }

        if settings.CORS_ORIGINS:
            # Replace, rather than duplicate, any CORS headers set upstream.
            filtered = {
                key: value
                for key, value in filtered.items()
                if not key.lower().startswith("access-control-")
            }
            allow_origin = "*" if request is None else _cors_allow_origin(request)
            if allow_origin is not None:
                filtered["Access-Control-Allow-Origin"] = allow_origin
                if request is not None:
                    _add_vary_origin(filtered)
            filtered["Access-Control-Allow-Methods"] = CORS_ALLOW_METHODS
            filtered["Access-Control-Allow-Headers"] = "Content-Type, Authorization"

        return filtered


# Initialize proxy
user_proxy = UserProxy()

# ================================================================
# USER MANAGEMENT ENDPOINTS - Proxied to auth service
# ================================================================

@router.get("/me")
async def get_current_user(request: Request):
    """Proxy get current user to auth service"""
    return await user_proxy.forward_request("GET", "me", request)


@router.put("/me")
async def update_current_user(request: Request):
    """Proxy update current user to auth service"""
    return await user_proxy.forward_request("PUT", "me", request)


@router.get("/delete/{user_id}/deletion-preview")
async def preview_user_deletion(user_id: str, request: Request):
    """Proxy user deletion preview to auth service"""
    return await user_proxy.forward_request("GET", f"delete/{user_id}/deletion-preview", request)


@router.delete("/delete/{user_id}")
async def delete_user(user_id: str, request: Request):
    """Proxy admin user deletion to auth service"""
    return await user_proxy.forward_request("DELETE", f"delete/{user_id}", request)
# ================================================================
# CATCH-ALL ROUTE for any other user endpoints
# ================================================================

# Registered after the explicit routes, so those are matched first.
@router.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
)
async def proxy_auth_requests(path: str, request: Request):
    """Relay any other user endpoint to the auth service.

    The incoming HTTP method is reused for the upstream call, and ``path`` is
    handed to ``user_proxy.forward_request`` unchanged.
    """
    incoming_method = request.method
    proxied = await user_proxy.forward_request(incoming_method, path, request)
    return proxied