Files
bakery-ia/gateway/app/middleware/auth.py
2025-07-26 18:46:52 +02:00

256 lines
9.8 KiB
Python

# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
"""
import structlog
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional, Dict, Any
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
# Structured logger shared by everything in this module.
logger = structlog.get_logger()

# Gateway-side JWT validator, built from the configured secret and algorithm.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Paths that AuthMiddleware lets through without a bearer token.
PUBLIC_ROUTES = [
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    "/api/v1/nominatim/search",
]
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control.

    For every request that is neither a CORS preflight nor a public route,
    the middleware extracts a bearer token, resolves it to a user context
    (local JWT check, then Redis cache, then the auth service), verifies
    tenant access for tenant-scoped paths and injects identity headers
    before handing the request to the next handler.
    """

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            redis_client: Optional client exposing awaitable ``get`` and
                ``setex``; when omitted, token caching is skipped.
        """
        super().__init__(app)
        self.redis_client = redis_client  # For caching and rate limiting

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process request with enhanced authentication and tenant access control.

        Returns:
            The downstream response, or a JSON 401 (missing/invalid token)
            or 403 (tenant access denied) response.
        """
        # Skip authentication for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path

        # Skip authentication for public routes
        if self._is_public_route(path):
            return await call_next(request)

        # STEP 1: Extract and validate JWT token
        token = self._extract_token(request)
        if not token:
            logger.warning(f"Missing token for protected route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # STEP 2: Verify token and get user context
        user_context = await self._verify_token(token)
        if not user_context:
            logger.warning(f"Invalid token for route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired token"},
            )

        # Only the locally validated path guarantees these keys, so read
        # them defensively instead of risking a KeyError (HTTP 500).
        user_id = user_context.get("user_id")
        user_email = user_context.get("email")

        # STEP 3: Extract tenant context from URL using shared utility
        tenant_id = extract_tenant_id_from_path(path)
        tenant_verified = False

        # STEP 4: Verify tenant access if this is a tenant-scoped route
        if tenant_id and is_tenant_scoped_path(path):
            # Let the shared manager reuse this middleware's Redis client
            # when it has none of its own.
            if self.redis_client and tenant_access_manager.redis_client is None:
                tenant_access_manager.redis_client = self.redis_client

            # A context without a user id can never be granted access.
            has_access = (
                user_id is not None
                and await tenant_access_manager.verify_basic_tenant_access(
                    user_id,
                    tenant_id,
                )
            )
            if not has_access:
                logger.warning(f"User {user_email} denied access to tenant {tenant_id}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": f"Access denied to tenant {tenant_id}"},
                )

            # Set tenant context in request state
            tenant_verified = True
            request.state.tenant_id = tenant_id
            request.state.tenant_verified = True
            logger.debug(
                "Tenant access verified",
                user_id=user_id,
                tenant_id=tenant_id,
                path=path,
            )

        # STEP 5: Inject user context into request
        request.state.user = user_context
        request.state.authenticated = True

        # STEP 6: Add context headers for downstream services. The verified
        # flag is passed explicitly so "x-tenant-verified" is only asserted
        # when the access check above actually ran and succeeded.
        self._inject_context_headers(
            request,
            user_context,
            tenant_id,
            tenant_verified=tenant_verified,
        )

        logger.debug(
            "Authenticated request",
            user_email=user_email,
            tenant_id=tenant_id,
            path=path,
        )
        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """
        Return True when *path* is a public route that skips authentication.

        A path matches an entry of PUBLIC_ROUTES when it equals the entry or
        lies beneath it (entry followed by "/"). A bare prefix test would
        also let through unrelated paths such as "/healthcheck-internal",
        so sibling endpoints must be listed in PUBLIC_ROUTES explicitly.
        """
        return any(
            path == route or path.startswith(f"{route}/")
            for route in PUBLIC_ROUTES
        )

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract JWT token from the Authorization header.

        Returns:
            The credentials following a ``Bearer`` scheme (matched
            case-insensitively), or None when the header is absent, uses
            another scheme, or carries no credentials.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None

        # Tolerate extra whitespace; an empty token counts as missing.
        return credentials.strip() or None

    async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify JWT token with fallback strategy.

        Tries, in order: local JWT validation, the Redis cache (when a
        client is configured) and the auth service. A successful auth
        service answer is cached.

        Returns:
            The user context dict, or None when every strategy fails.
        """
        # Try local JWT validation first (fast)
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                return payload
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Check cache for recently validated tokens
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Verify with auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                # Cache successful validation
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """Return True when the JWT payload has user_id, email and exp."""
        required_fields = ("user_id", "email", "exp")
        return all(field in payload for field in required_fields)

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify token with the auth service.

        Posts the bearer token to ``/api/v1/auth/verify`` with a 3 second
        timeout.

        Returns:
            The decoded JSON body on HTTP 200, otherwise None (including on
            any error, which is logged).
        """
        try:
            import httpx

            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )
                if response.status_code == 200:
                    return response.json()

                logger.warning(f"Auth service returned {response.status_code}")
                return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Build the Redis key under which a token's user context is cached.

        Uses a SHA-256 digest rather than the builtin ``hash()``: string
        hashes are salted per process, so they would differ between gateway
        workers and restarts, and they are not collision resistant.
        """
        import hashlib

        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        return f"auth:token:{digest}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Get user context from cache.

        Returns:
            The cached user context, or None when no client is configured,
            the key is absent, or the lookup fails (failure is logged).
        """
        if not self.redis_client:
            return None

        cache_key = self._token_cache_key(token)
        try:
            cached_data = await self.redis_client.get(cache_key)
            if cached_data:
                import json

                if isinstance(cached_data, bytes):
                    cached_data = cached_data.decode()
                return json.loads(cached_data)
        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
        """
        Cache user context (5 minutes by default).

        Args:
            token: The bearer token the context belongs to.
            user_context: JSON-serialisable user context to store.
            ttl: Expiry in seconds.

        Failures are logged and otherwise ignored (best-effort cache).
        """
        if not self.redis_client:
            return

        cache_key = self._token_cache_key(token)
        try:
            import json

            await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")

    def _inject_context_headers(
        self,
        request: Request,
        user_context: Dict[str, Any],
        tenant_id: Optional[str],
        tenant_verified: bool = True,
    ):
        """
        Inject authentication and tenant headers for downstream services.

        Args:
            request: The request whose headers are rewritten.
            user_context: Source of the user id, email and role values.
            tenant_id: Tenant id to forward, or None to send no tenant headers.
            tenant_verified: Whether to assert ``x-tenant-verified``; defaults
                to True so existing three-argument calls behave as before.
        """
        # Remove any existing auth headers to prevent spoofing. Raw header
        # names are bytes, so the deny-list must be bytes too; comparing
        # them against str never matches and would strip nothing.
        headers_to_remove = {
            b"x-user-id", b"x-user-email", b"x-user-role",
            b"x-tenant-id", b"x-tenant-verified", b"x-authenticated",
        }
        sanitized = [
            (k, v) for k, v in request.headers.raw
            if k.lower() not in headers_to_remove
        ]

        # Inject new headers
        sanitized.extend([
            (b"x-authenticated", b"true"),
            (b"x-user-id", str(user_context.get("user_id", "")).encode()),
            (b"x-user-email", str(user_context.get("email", "")).encode()),
            (b"x-user-role", str(user_context.get("role", "user")).encode()),
        ])

        # Add tenant context; only claim verification when it happened
        if tenant_id:
            sanitized.append((b"x-tenant-id", str(tenant_id).encode()))
            if tenant_verified:
                sanitized.append((b"x-tenant-verified", b"true"))

        # Publish the rewritten list on the ASGI scope, which is what gets
        # passed on to the wrapped application, and on the Headers object
        # already cached by this request.
        # NOTE(review): relies on Starlette's Headers storing its raw list
        # in ``_list``, as the previous implementation did - confirm on upgrade.
        request.scope["headers"] = sanitized
        request.headers.__dict__["_list"] = sanitized

        logger.debug(
            "Injected context headers",
            user_id=user_context.get("user_id"),
            tenant_id=tenant_id,
        )