Files
bakery-ia/gateway/app/middleware/auth.py
2025-07-18 16:48:49 +02:00

122 lines
4.6 KiB
Python

import logging
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import httpx
from typing import Optional
import json
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
logger = logging.getLogger(__name__)

# Module-wide JWT handler used for in-process (local) token verification,
# configured from the gateway settings.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Paths that are exempt from authentication in the gateway. How a request
# path is compared against these entries is decided by
# AuthMiddleware._is_public_route.
PUBLIC_ROUTES = [
    # Operational endpoints
    "/health",
    "/metrics",
    # API documentation
    "/docs",
    "/redoc",
    "/openapi.json",
    # Auth endpoints that must be reachable without a valid token
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
]
class AuthMiddleware(BaseHTTPMiddleware):
    """Authentication middleware for the gateway.

    Requests whose path is public (see ``_is_public_route``) pass straight
    through. Every other request must carry an ``Authorization: Bearer <token>``
    header. The token is first verified locally with ``jwt_handler``. If that
    yields no payload, verification is delegated to the auth service.

    On success, the verified user info is stored on ``request.state.user`` and
    the request is forwarded. Otherwise a 401 JSON response is returned.
    """

    # Claims a locally verified token must carry before it is accepted.
    REQUIRED_CLAIMS = ("user_id", "email", "tenant_id")

    # Seconds to wait for the auth service before treating verification as failed.
    AUTH_SERVICE_TIMEOUT = 5.0

    async def dispatch(self, request: Request, call_next) -> Response:
        """Authenticate *request*, then forward it to the next handler.

        Returns:
            The downstream response on success, or a 401 ``JSONResponse``
            when the token is missing, invalid, or lacks required claims.

        Exceptions raised by downstream handlers (through ``call_next``) are
        intentionally not caught here, so they are not misreported as
        authentication failures.
        """
        if self._is_public_route(request.url.path):
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            logger.warning("Missing token for %s", request.url.path)
            return self._unauthorized("Authentication required")

        # Only token verification is guarded. call_next stays OUTSIDE this try
        # so that an error in a downstream route surfaces as that error rather
        # than being converted into a 401 "Authentication failed".
        try:
            payload = jwt_handler.verify_token(token)
            if payload:
                missing_fields = [
                    field for field in self.REQUIRED_CLAIMS if field not in payload
                ]
                if missing_fields:
                    logger.warning("Token missing required fields: %s", missing_fields)
                    return self._unauthorized(f"Invalid token: missing {missing_fields}")
                user = payload
                logger.debug(
                    "Authenticated user: %s (tenant: %s)",
                    payload.get("email"),
                    payload.get("tenant_id"),
                )
            else:
                # Local verification gave no payload (invalid/expired, or a key
                # mismatch); let the auth service have the final word.
                logger.info("Local token verification failed, trying auth service")
                user = await self._verify_with_auth_service(token)
                if not user:
                    logger.warning("Auth service verification also failed")
                    return self._unauthorized("Invalid or expired token")
        except Exception as e:
            logger.exception("Authentication error: %s", e)
            return self._unauthorized("Authentication failed")

        request.state.user = user
        return await call_next(request)

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build a 401 JSON response carrying *detail*.

        RFC 7235 requires a 401 response to include a ``WWW-Authenticate``
        header naming the accepted scheme, which is Bearer here.
        """
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _is_public_route(self, path: str) -> bool:
        """Return True if *path* equals a PUBLIC_ROUTES entry or lies beneath one.

        Matching respects path-segment boundaries, so "/docs" and "/docs/x"
        are public but "/docsx" is not. A bare ``startswith`` would let paths
        such as "/healthz-admin" or "/api/v1/auth/verify-anything" skip
        authentication.
        """
        return any(
            path == route or path.startswith(route.rstrip("/") + "/")
            for route in PUBLIC_ROUTES
        )

    def _extract_token(self, request: Request) -> Optional[str]:
        """Return the Bearer token from the Authorization header, or None.

        The scheme name is matched case-insensitively (RFC 7235 §2.1).
        Surrounding whitespace around the credentials is ignored, and an
        empty credential yields None.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        token = credentials.strip()
        return token or None

    async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
        """Ask the auth service to verify *token*.

        POSTs to ``{settings.AUTH_SERVICE_URL}/api/v1/auth/verify`` with the
        token as a Bearer credential. This is best-effort: a non-200 status, a
        transport error, or a body that is not a JSON object is logged and
        reported as ``None``.

        Returns:
            The decoded JSON object from a 200 response, or ``None``.
        """
        try:
            async with httpx.AsyncClient(timeout=self.AUTH_SERVICE_TIMEOUT) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )
                if response.status_code != 200:
                    logger.warning(
                        "Auth service verification failed: %s", response.status_code
                    )
                    return None
                user_info = response.json()
        except Exception as e:
            # Deliberate catch-all: any failure here means "not verified".
            logger.error("Auth service verification failed: %s", e)
            return None

        if not isinstance(user_info, dict):
            logger.warning(
                "Auth service returned unexpected payload type: %s",
                type(user_info).__name__,
            )
            return None
        # NOTE(review): the required-claims check applied to local tokens is not
        # applied here; the auth service response schema is not visible here.
        logger.debug("Auth service verification successful: %s", user_info.get("email"))
        return user_info