Files
bakery-ia/gateway/app/middleware/auth.py

285 lines
11 KiB
Python
Raw Normal View History

2025-07-19 17:49:03 +02:00
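# gateway/app/middleware/auth.py
"""
Authentication middleware for the API gateway.

Validates bearer tokens (locally first, then via the auth service), verifies
the caller's access to the requested tenant, and forwards the authenticated
user context to downstream services.
"""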
import structlog
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
2025-07-17 19:54:04 +02:00
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import httpx
2025-07-19 17:49:03 +02:00
from typing import Optional, Dict, Any
import asyncio
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
2025-07-19 17:49:03 +02:00
# Structured logger shared by everything in this module.
logger = structlog.get_logger()

# Handler for fast, local JWT validation; the auth service is only consulted
# when this check does not succeed.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Paths that do not require authentication (consulted by
# AuthMiddleware._is_public_route).
PUBLIC_ROUTES = [
    # Operational / documentation endpoints
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    # Authentication endpoints that must be reachable without a token
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
]
2025-07-17 19:54:04 +02:00
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Authentication middleware for the API gateway.

    Responsibilities:
    1. Token validation (local JWT check first, then the auth service)
    2. User context injection into ``request.state`` and forwarded headers
    3. Tenant context extraction and access verification (per request)

    ``redis_client`` (optional) caches tokens validated by the auth service.
    """

    # Identity headers that only the gateway may set. Client-supplied copies
    # are always stripped so downstream services cannot be spoofed.
    _FORWARDED_AUTH_HEADERS = frozenset({
        b"x-user-id",
        b"x-user-email",
        b"x-user-role",
        b"x-tenant-id",
        b"x-user-permissions",
        b"x-authenticated",
    })

    def __init__(self, app, redis_client=None):
        super().__init__(app)
        self.redis_client = redis_client  # For caching validated tokens

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Authenticate the request, then pass it downstream.

        Returns a 401 JSON response when the bearer token is missing or
        invalid, and a 403 when the user may not access the referenced tenant.
        """
        # Public routes skip authentication, but must not carry forged
        # identity headers downstream either.
        if self._is_public_route(request.url.path):
            self._replace_auth_headers(request, [])
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            logger.warning(f"Missing token for protected route: {request.url.path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"}
            )

        user_context = await self._verify_token(token)
        if not user_context:
            logger.warning(f"Invalid token for route: {request.url.path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired token"}
            )

        # Tenant context comes from the request itself, not from the JWT.
        tenant_id = self._extract_tenant_from_request(request)
        if tenant_id:
            # Auth-service responses are not schema-checked, so use .get():
            # a missing user_id denies access instead of raising KeyError.
            user_id = user_context.get("user_id")
            has_access = bool(user_id) and await self._verify_tenant_access(
                str(user_id), tenant_id
            )
            if not has_access:
                logger.warning(
                    f"User {user_context.get('email')} denied access to tenant {tenant_id}"
                )
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Access denied to tenant"}
                )
            request.state.tenant_id = tenant_id

        request.state.user = user_context
        request.state.authenticated = True

        # Forward the verified identity to downstream services via headers.
        self._inject_auth_headers(request, user_context, tenant_id)

        logger.debug(f"Authenticated request: {user_context.get('email')} -> {request.url.path}")
        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """
        Return True if *path* is a public route or lies beneath one.

        Matching is on whole path segments: "/docs" and "/docs/oauth2-redirect"
        are public, but "/docs-internal" is not. A plain prefix match would let
        any path that merely starts with a public route bypass authentication.
        """
        for route in PUBLIC_ROUTES:
            base = route.rstrip("/")
            if path == route or path == base or path.startswith(base + "/"):
                return True
        return False

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Return the bearer token from the Authorization header, or None.

        The scheme name is matched case-insensitively (RFC 7235), and
        surrounding whitespace around the token is ignored.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        token = credentials.strip()
        return token or None

    def _extract_tenant_from_request(self, request: Request) -> Optional[str]:
        """
        Extract the tenant ID from the request (NOT from the JWT token).

        Priority order:
        1. X-Tenant-ID header
        2. tenant_id query parameter
        3. the path segment after "tenants" (e.g. /api/v1/tenants/{tenant_id}/...)

        Returns None when no non-empty tenant ID is present.

        NOTE(review): only the one tenant ID chosen here is access-checked. If
        the header and the path name different tenants, downstream services
        that read the tenant from the path are not covered by the check.
        Confirm they use the forwarded x-tenant-id header instead.
        """
        tenant_id = request.headers.get("X-Tenant-ID")
        if tenant_id:
            return tenant_id

        tenant_id = request.query_params.get("tenant_id")
        if tenant_id:
            return tenant_id

        path_parts = request.url.path.split("/")
        try:
            candidate = path_parts[path_parts.index("tenants") + 1]
        except (ValueError, IndexError):
            # No "tenants" segment, or it is the last segment.
            return None
        # A trailing slash ("/tenants/") yields an empty segment, not a tenant.
        return candidate or None

    async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify a JWT with a fallback strategy and return the user context.

        1. Local validation (fast)
        2. Cache lookup of tokens previously accepted by the auth service
        3. Auth service validation (authoritative); successes are cached

        Returns None if no step accepts the token.
        """
        # Step 1: local JWT validation.
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                return payload
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Step 2: recently validated tokens.
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Step 3: auth service.
        try:
            user_context = await self._verify_with_auth_service(token)
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")
            return None
        if not user_context:
            return None

        if self.redis_client:
            try:
                await self._cache_user(token, user_context)
            except Exception as e:
                # A cache outage must not reject a token the auth service accepted.
                logger.warning(f"Failed to cache validated token: {e}")

        logger.debug("Token validated by auth service")
        return user_context

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """Return True if the JWT payload has every required field."""
        required_fields = ["user_id", "email", "exp"]
        return all(field in payload for field in required_fields)

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify the token with the auth service.

        Returns the service's JSON object on HTTP 200. Returns None on any other
        status, on a non-object body, on timeout, or on any other error.
        """
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"}
                )
                if response.status_code != 200:
                    logger.warning(f"Auth service returned {response.status_code}")
                    return None
                payload = response.json()
            # Callers treat the result as a dict (.get, state, cache), so
            # reject anything that is not a JSON object.
            if not isinstance(payload, dict):
                logger.warning("Auth service returned a non-object payload")
                return None
            return payload
        except (httpx.TimeoutException, asyncio.TimeoutError):
            # httpx signals timeouts with httpx.TimeoutException subclasses,
            # not asyncio.TimeoutError.
            logger.error("Auth service timeout")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    async def _verify_tenant_access(self, user_id: str, tenant_id: str) -> bool:
        """
        Return True if the tenant service reports that *user_id* may access
        *tenant_id*. Any error is treated as "no access".
        """
        from urllib.parse import quote

        # tenant_id comes from client-controlled input (header, query, path).
        # Percent-encode both segments so values like "a/access/b?" cannot
        # rewrite the tenant-service URL.
        url = (
            f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/"
            f"{quote(str(tenant_id), safe='')}/access/{quote(str(user_id), safe='')}"
        )
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.get(url)
                return response.status_code == 200
        except Exception as e:
            logger.error(f"Tenant access verification failed: {e}")
            return False

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Build a stable Redis key for *token* without storing the token itself.

        Uses SHA-256 rather than the built-in hash(). hash() is randomized per
        process, so its keys would not be shared across gateway workers or
        restarts, and its 64-bit range makes cross-user collisions possible.
        """
        import hashlib

        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        return f"auth:token:{digest}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the cached user context for *token*, or None if absent."""
        if not self.redis_client:
            return None
        cached_data = await self.redis_client.get(self._token_cache_key(token))
        if not cached_data:
            return None
        import json

        return json.loads(cached_data)

    async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
        """
        Cache the user context for *ttl* seconds (default 5 minutes).

        If the context carries a numeric JWT-style ``exp``, the TTL is capped
        so the entry does not outlive the token, and already-expired tokens
        are not cached. This assumes ``exp`` is Unix seconds per RFC 7519;
        confirm the auth service follows that.
        """
        if not self.redis_client:
            return
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)) and not isinstance(exp, bool):
            import time

            remaining = int(exp - time.time())
            if remaining <= 0:
                return
            ttl = min(ttl, remaining)
        import json

        await self.redis_client.setex(
            self._token_cache_key(token), ttl, json.dumps(user_context)
        )

    def _replace_auth_headers(self, request: Request, new_headers) -> None:
        """
        Remove every gateway-owned identity header from the request, then
        append *new_headers* (a list of ``(bytes, bytes)`` pairs).

        Edits ``request.scope["headers"]``, the raw ASGI header list that
        downstream apps build their requests from. Assigning to the Headers
        object's private ``_list`` would not reach downstream.
        """
        current = request.scope.get("headers") or []
        updated = [
            (name, value)
            for name, value in current
            if name.lower() not in self._FORWARDED_AUTH_HEADERS
        ]
        updated.extend(new_headers)
        if isinstance(current, list):
            # Mutate in place so any Headers view already wrapping this list
            # (created when request.headers was first read) stays in sync.
            current[:] = updated
        else:
            request.scope["headers"] = updated

    def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
        """
        Inject authentication headers for downstream services.

        This lets services work both:
        1. Behind the gateway (using request.state)
        2. Called directly (using headers) for development/testing

        Any client-supplied copies of these headers are removed first, to
        prevent spoofing.
        """
        new_headers = [
            (b"x-authenticated", b"true"),
            (b"x-user-id", str(user_context.get("user_id", "")).encode()),
            (b"x-user-email", str(user_context.get("email", "")).encode()),
            (b"x-user-role", str(user_context.get("role", "user")).encode()),
        ]

        if tenant_id:
            new_headers.append((b"x-tenant-id", tenant_id.encode()))

        permissions = user_context.get("permissions", [])
        if isinstance(permissions, str):
            # A single permission string must not be split into characters.
            permissions = [permissions]
        if permissions:
            new_headers.append(
                (b"x-user-permissions", ",".join(str(p) for p in permissions).encode())
            )

        self._replace_auth_headers(request, new_headers)
        logger.debug(f"Injected auth headers for user {user_context.get('email')}")