# gateway/app/middleware/auth.py - IMPROVED VERSION
"""
Enhanced Authentication Middleware for API Gateway
Implements proper token validation and tenant context extraction
"""

import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Iterable, List, Optional, Tuple
from urllib.parse import quote

import httpx
import structlog
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler

logger = structlog.get_logger()

# JWT handler for local token validation
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Routes that don't require authentication. Each entry matches the exact path
# and anything beneath it ("/docs" covers "/docs/oauth2-redirect"), but not a
# sibling path that merely shares the prefix ("/docs-internal").
PUBLIC_ROUTES = [
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
]

# Identity headers only the gateway may set. Client-supplied copies are always
# stripped (ASGI header names are lowercase bytes) so downstream services can
# trust whatever values arrive.
_IDENTITY_HEADERS = frozenset(
    {
        b"x-user-id",
        b"x-user-email",
        b"x-user-role",
        b"x-user-permissions",
        b"x-authenticated",
    }
)

# On authenticated routes the tenant header is rewritten as well, so it carries
# exactly the tenant ID whose access was verified.
_GATEWAY_AUTH_HEADERS = _IDENTITY_HEADERS | {b"x-tenant-id"}


class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware following microservices best practices

    Responsibilities:
    1. Token validation (local JWT first, then cache, then auth service)
    2. User context injection (request.state plus trusted X-User-* headers)
    3. Tenant context extraction and access verification (per request)

    Rate limiting and request routing are NOT performed here; the optional
    redis_client is only used to cache auth-service token validations.
    """

    def __init__(self, app, redis_client=None):
        super().__init__(app)
        # Optional client exposing awaitable get()/setex(); used as a token cache.
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Authenticate the request, verify tenant access, then forward it.

        Returns:
            401 JSONResponse if the bearer token is missing or invalid,
            403 JSONResponse if the user lacks access to any tenant the
            request references, otherwise the downstream response.
        """
        path = request.url.path

        if self._is_public_route(path):
            # No authentication needed, but clients must not be able to smuggle
            # identity headers through to downstream services.
            self._rewrite_headers(request, _IDENTITY_HEADERS, [])
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            logger.warning(f"Missing token for protected route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        user_context = await self._verify_token(token)
        if not user_context:
            logger.warning(f"Invalid token for route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired token"},
            )

        # Tenant context comes from the request (not from the JWT).
        tenant_id = self._extract_tenant_from_request(request)

        # Verify EVERY tenant the request mentions, not only the highest-priority
        # one: otherwise a client could send an X-Tenant-ID it owns while the
        # path (/tenants/<other>/...) targets a tenant it does not.
        for candidate in self._collect_tenant_ids(request):
            has_access = await self._verify_tenant_access(user_context.get("user_id"), candidate)
            if not has_access:
                logger.warning(f"User {user_context.get('email')} denied access to tenant {candidate}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Access denied to tenant"},
                )

        # Always set the attribute so downstream code can read it without getattr.
        request.state.tenant_id = tenant_id
        request.state.user = user_context
        request.state.authenticated = True

        # Add user context to forwarded requests
        self._inject_auth_headers(request, user_context, tenant_id)

        logger.debug(f"Authenticated request: {user_context.get('email')} -> {path}")
        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """Return True if path is a public route or lies beneath one (segment-wise)."""
        for route in PUBLIC_ROUTES:
            if path == route or path.startswith(route.rstrip("/") + "/"):
                return True
        return False

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract the bearer token from the Authorization header.

        The scheme name is matched case-insensitively (per HTTP auth syntax).
        Returns None when the header is absent, uses another scheme, or carries
        an empty credential.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    @staticmethod
    def _tenant_from_path(path: str) -> Optional[str]:
        """Return the segment after the first "tenants" segment, or None if absent/empty."""
        parts = path.split("/")
        try:
            candidate = parts[parts.index("tenants") + 1]
        except (ValueError, IndexError):
            return None
        return candidate or None

    def _extract_tenant_from_request(self, request: Request) -> Optional[str]:
        """
        Extract the primary tenant ID from the request (NOT from the JWT token).

        Priority order:
        1. X-Tenant-ID header
        2. tenant_id query parameter
        3. tenant_id in request path (/api/v1/tenants/{tenant_id}/...)

        Returns None when no source provides a non-empty value.
        """
        return (
            request.headers.get("X-Tenant-ID")
            or request.query_params.get("tenant_id")
            or self._tenant_from_path(request.url.path)
        )

    def _collect_tenant_ids(self, request: Request) -> List[str]:
        """
        Return every distinct, non-empty tenant ID the request references.

        Includes all X-Tenant-ID header values, all tenant_id query values and
        the path tenant, so the primary ID from _extract_tenant_from_request is
        always among them.
        """
        candidates = [
            *request.headers.getlist("x-tenant-id"),
            *request.query_params.getlist("tenant_id"),
            self._tenant_from_path(request.url.path),
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(c for c in candidates if c))

    @staticmethod
    def _has_identity(context: Any) -> bool:
        """Return True if context is a dict carrying a non-empty user_id."""
        return isinstance(context, dict) and context.get("user_id") not in (None, "")

    async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify JWT token with fallback strategy:
        1. Local validation (fast)
        2. Cache of recent auth-service validations (if redis_client is set)
        3. Auth service validation (authoritative); successes are cached

        Returns the user context dict, or None if the token is not accepted.
        Never raises: every failure is logged and treated as "not valid here".
        """
        # Step 1: local JWT validation
        try:
            payload = jwt_handler.verify_token(token)
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")
        else:
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                return payload

        # Step 2: recently validated tokens
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")
            else:
                if self._has_identity(cached_user):
                    logger.debug("Token found in cache")
                    return cached_user

        # Step 3: auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")
            return None

        # NOTE(review): the auth-service response schema isn't visible here; we
        # only require a user_id so dispatch can't fail on a missing key.
        if not self._has_identity(user_context):
            return None

        if self.redis_client:
            try:
                await self._cache_user(token, user_context)
            except Exception as e:
                # A cache write failure must not reject a token the auth
                # service has just accepted.
                logger.warning(f"Cache write failed: {e}")

        logger.debug("Token validated by auth service")
        return user_context

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Validate that a locally decoded JWT payload is usable.

        Requires user_id, email and exp; additionally rejects a numeric exp
        that is already in the past (defence in depth in case the JWT handler
        does not enforce expiry itself).
        """
        if not isinstance(payload, dict):
            return False
        required_fields = ["user_id", "email", "exp"]
        if not all(field in payload for field in required_fields):
            return False
        exp = payload["exp"]
        if isinstance(exp, (int, float)) and exp <= time.time():
            return False
        return True

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify token with the auth service.

        Returns the parsed JSON body on HTTP 200, otherwise None. Transport
        errors, timeouts and undecodable bodies are logged and yield None.
        """
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )
            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}")
                return None
            return response.json()
        # httpx signals timeouts with httpx.TimeoutException, not asyncio's.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            logger.error("Auth service timeout")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    async def _verify_tenant_access(self, user_id: str, tenant_id: str) -> bool:
        """
        Ask the tenant service whether user_id may access tenant_id.

        Fails closed: any non-200 status, malformed ID or error yields False.
        """
        # tenant_id is client-controlled. Unescaped, a value such as
        # "other/access/<allowed-user>?" would rewrite the internal URL and check
        # a different tenant/user pair. Dot segments are rejected because they
        # may be resolved to a different endpoint during URL normalisation.
        if str(tenant_id) in (".", ".."):
            logger.warning(f"Rejected malformed tenant id: {tenant_id!r}")
            return False
        url = (
            f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/"
            f"{quote(str(tenant_id), safe='')}/access/{quote(str(user_id), safe='')}"
        )
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.get(url)
            return response.status_code == 200
        except Exception as e:
            logger.error(f"Tenant access verification failed: {e}")
            return False

    @staticmethod
    def _cache_key(token: str) -> str:
        """Return a stable Redis key for token."""
        # Built-in hash() is salted per process (PYTHONHASHSEED) and only 64
        # bits wide: keys would differ across replicas/restarts and could
        # collide. SHA-256 is stable and collision-resistant, and the raw
        # token is never stored as a key.
        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        return f"auth:token:{digest}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the cached user context for token, or None on a cache miss."""
        if not self.redis_client:
            return None
        cached_data = await self.redis_client.get(self._cache_key(token))
        if cached_data:
            return json.loads(cached_data)
        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
        """
        Cache user context for up to ttl seconds (default 5 minutes).

        If the context carries a numeric exp, the TTL is clamped so the entry
        never outlives the token; already-expired contexts are not cached.
        """
        if not self.redis_client:
            return
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)):
            ttl = min(ttl, int(exp - time.time()))
        if ttl <= 0:
            return
        await self.redis_client.setex(self._cache_key(token), ttl, json.dumps(user_context))

    @staticmethod
    def _rewrite_headers(
        request: Request,
        strip: Iterable[bytes],
        add: List[Tuple[bytes, bytes]],
    ) -> None:
        """
        Replace the request's raw ASGI header list in place.

        Drops every header whose (lowercased) name is in strip, then appends
        add. The new list is written to request.scope["headers"], the scope
        dict that is handed on to the wrapped app (Starlette BaseHTTPMiddleware
        behaviour — re-check on framework upgrades).
        """
        blocked = frozenset(strip)
        rewritten = [
            (name, value)
            for name, value in request.scope.get("headers", [])
            if name.lower() not in blocked
        ]
        rewritten.extend(add)
        request.scope["headers"] = rewritten
        # request.headers is cached on first access; drop the stale view so
        # later reads in this middleware see the rewritten list.
        try:
            del request._headers
        except AttributeError:
            pass

    def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
        """
        Inject authentication headers for downstream services

        Any client-supplied X-User-*/X-Authenticated/X-Tenant-ID headers are
        removed first (anti-spoofing), then the gateway-verified values are
        added. This allows services to work both:
        1. Behind the gateway (using request.state)
        2. Called directly (using headers) for development/testing
        """
        injected = [
            (b"x-authenticated", b"true"),
            (b"x-user-id", str(user_context.get("user_id", "")).encode()),
            (b"x-user-email", str(user_context.get("email", "")).encode()),
            (b"x-user-role", str(user_context.get("role", "user")).encode()),
        ]
        if tenant_id:
            injected.append((b"x-tenant-id", tenant_id.encode()))

        permissions = user_context.get("permissions") or []
        if permissions:
            # A bare string is one permission, not a sequence of characters.
            if isinstance(permissions, str):
                joined = permissions
            else:
                joined = ",".join(str(p) for p in permissions)
            injected.append((b"x-user-permissions", joined.encode()))

        self._rewrite_headers(request, _GATEWAY_AUTH_HEADERS, injected)
        logger.debug(f"Injected auth headers for user {user_context.get('email')}")