import json
import logging
from typing import Optional

import httpx
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler

logger = logging.getLogger(__name__)

# JWT handler used for local (in-process) token verification.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Routes that don't require authentication. Each entry matches the path itself
# and any sub-path below it (e.g. "/docs" also covers "/docs/oauth2-redirect"),
# but NOT sibling paths that merely share the same leading characters
# (e.g. "/healthz" or "/api/v1/auth/login-admin").
PUBLIC_ROUTES = [
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
]

# Claims a locally verified token payload must contain to be accepted.
REQUIRED_TOKEN_FIELDS = ("user_id", "email", "tenant_id")


class AuthMiddleware(BaseHTTPMiddleware):
    """Authenticate requests with a bearer token before they reach route handlers.

    Requests to ``PUBLIC_ROUTES`` are forwarded untouched. For every other
    route the bearer token is first verified locally with ``jwt_handler``.
    If that yields no payload, verification falls back to the auth service.
    On success the user information is stored on ``request.state.user`` and
    the request is forwarded. Otherwise a 401 JSON response is returned.
    """

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate ``request`` and forward it to ``call_next`` on success.

        Returns a 401 ``JSONResponse`` when the token is missing, invalid,
        lacks required claims, or when verification itself raises. Exceptions
        raised by downstream handlers (inside ``call_next``) are deliberately
        NOT caught here. They propagate as server errors instead of being
        misreported as authentication failures.
        """
        if self._is_public_route(request.url.path):
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            logger.warning("Missing token for %s", request.url.path)
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # The try covers token verification only; call_next runs after it.
        try:
            payload = jwt_handler.verify_token(token)
            if payload:
                missing_fields = [
                    field for field in REQUIRED_TOKEN_FIELDS if field not in payload
                ]
                if missing_fields:
                    logger.warning(
                        "Token missing required fields: %s", missing_fields
                    )
                    return JSONResponse(
                        status_code=401,
                        content={"detail": f"Invalid token: missing {missing_fields}"},
                    )
                logger.debug(
                    "Authenticated user: %s (tenant: %s)",
                    payload.get("email"),
                    payload.get("tenant_id"),
                )
                user = payload
            else:
                # Local verification returned nothing (invalid or expired token).
                # Fall back to the auth service.
                logger.info("Local token verification failed, trying auth service")
                # NOTE(review): unlike the local path, the auth-service payload is
                # not checked against REQUIRED_TOKEN_FIELDS. Confirm its schema.
                user = await self._verify_with_auth_service(token)
                if not user:
                    logger.warning("Auth service verification also failed")
                    return JSONResponse(
                        status_code=401,
                        content={"detail": "Invalid or expired token"},
                    )
        except Exception as e:
            logger.exception("Authentication error: %s", e)
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication failed"},
            )

        request.state.user = user
        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """Return True if ``path`` equals a public route or lies beneath one.

        Matching happens on path-segment boundaries. "/docs" and
        "/docs/oauth2-redirect" are public, while "/docsecret" is not.
        """
        for route in PUBLIC_ROUTES:
            prefix = route.rstrip("/")
            if path == prefix or path.startswith(prefix + "/"):
                return True
        return False

    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract the bearer token from the ``Authorization`` header.

        The scheme name is compared case-insensitively (RFC 7235). Returns
        ``None`` when the header is absent, uses another scheme, or carries an
        empty credential.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        token = credentials.strip()
        return token or None

    async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
        """Ask the auth service to verify ``token``.

        Returns the decoded JSON object on HTTP 200, otherwise ``None``.
        Network and decoding errors, and non-object JSON bodies, are logged.
        They also yield ``None``, which ``dispatch`` turns into a 401.
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )
                if response.status_code != 200:
                    logger.warning(
                        "Auth service verification failed: %s", response.status_code
                    )
                    return None
                user_info = response.json()
        except Exception as e:
            # Best-effort remote check: any failure is treated as "not verified".
            logger.error("Auth service verification failed: %s", e)
            return None

        if not isinstance(user_info, dict):
            logger.warning(
                "Auth service returned unexpected payload type: %s",
                type(user_info).__name__,
            )
            return None

        logger.debug(
            "Auth service verification successful: %s", user_info.get("email")
        )
        return user_info