2025-07-26 18:46:52 +02:00
|
|
|
# gateway/app/middleware/auth.py
|
2025-07-19 17:49:03 +02:00
|
|
|
"""
|
2025-07-26 18:46:52 +02:00
|
|
|
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
|
2025-07-19 17:49:03 +02:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import structlog
|
|
|
|
|
from fastapi import Request, HTTPException
|
2025-07-17 13:09:24 +02:00
|
|
|
from fastapi.responses import JSONResponse
|
2025-07-17 19:54:04 +02:00
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
from starlette.responses import Response
|
2025-07-19 17:49:03 +02:00
|
|
|
from typing import Optional, Dict, Any
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
from shared.auth.jwt_handler import JWTHandler
|
2025-07-26 18:46:52 +02:00
|
|
|
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 17:49:03 +02:00
|
|
|
logger = structlog.get_logger()

# Module-wide JWT handler: AuthMiddleware tries this local signature check
# first, before consulting the Redis cache or the auth service.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Path prefixes that skip authentication. AuthMiddleware._is_public_route
# treats a request as public when its path starts with any entry below.
PUBLIC_ROUTES = [
    # Infrastructure and API documentation
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    # Authentication endpoints
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    # Nominatim search endpoint
    "/api/v1/nominatim/search",
]
|
|
|
|
|
|
2025-07-17 19:54:04 +02:00
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control

    For each request that is neither a CORS preflight (OPTIONS) nor on a
    public route, this middleware:

    1. extracts the Bearer token from the ``Authorization`` header;
    2. verifies it (local JWT check, then the Redis cache, then the auth
       service) and answers 401 if that fails;
    3. for tenant-scoped paths, checks that the user may access the tenant
       named in the URL and answers 403 otherwise;
    4. stores the user/tenant context on ``request.state``; and
    5. rewrites the ``x-authenticated`` / ``x-user-*`` / ``x-tenant-*``
       request headers so downstream handlers only see values asserted by
       the gateway, never values supplied by the client.
    """

    # Trust headers owned by this middleware. Any incoming copies are stripped
    # before the gateway's own values are added. ASGI carries header names as
    # bytes, so they are stored as lowercase bytes for comparison.
    _CONTEXT_HEADERS = frozenset(
        {
            b"x-user-id",
            b"x-user-email",
            b"x-user-role",
            b"x-tenant-id",
            b"x-tenant-verified",
            b"x-authenticated",
        }
    )

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The downstream ASGI application.
            redis_client: Optional client with awaitable ``get``/``setex``,
                used to cache verified user contexts. It is also handed to
                ``tenant_access_manager`` if that has no client of its own.
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request with enhanced authentication and tenant access control.

        Returns a 401 JSON response for a missing or invalid token, a 403 JSON
        response when tenant access is denied, and otherwise the downstream
        response from ``call_next``.
        """
        path = request.url.path

        # Skip authentication for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Skip authentication for public routes
        if self._is_public_route(path):
            return await call_next(request)

        # STEP 1: Extract the JWT token
        token = self._extract_token(request)
        if not token:
            logger.warning(f"Missing token for protected route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # STEP 2: Verify the token and get the user context
        user_context = await self._verify_token(token)
        if not user_context:
            logger.warning(f"Invalid token for route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired token"},
            )

        # STEP 3: Extract the tenant context from the URL
        tenant_id = extract_tenant_id_from_path(path)
        tenant_verified = False

        # STEP 4: Verify tenant access if this is a tenant-scoped route
        if tenant_id and is_tenant_scoped_path(path):
            # Let the shared manager reuse our Redis client for its caching
            # when it was not configured with one.
            if self.redis_client and tenant_access_manager.redis_client is None:
                tenant_access_manager.redis_client = self.redis_client

            has_access = await tenant_access_manager.verify_basic_tenant_access(
                user_context["user_id"],
                tenant_id,
            )

            if not has_access:
                # .get(): auth-service payloads are not schema-checked, and a
                # missing email must not turn a 403 into a 500.
                logger.warning(
                    f"User {user_context.get('email')} denied access to tenant {tenant_id}"
                )
                return JSONResponse(
                    status_code=403,
                    content={"detail": f"Access denied to tenant {tenant_id}"},
                )

            request.state.tenant_id = tenant_id
            request.state.tenant_verified = True
            tenant_verified = True

            logger.debug(
                "Tenant access verified",
                user_id=user_context["user_id"],
                tenant_id=tenant_id,
                path=path,
            )

        # STEP 5: Inject the user context into the request
        request.state.user = user_context
        request.state.authenticated = True

        # STEP 6: Add context headers for downstream services
        self._inject_context_headers(
            request, user_context, tenant_id, tenant_verified=tenant_verified
        )

        logger.debug(
            "Authenticated request",
            user_email=user_context.get("email"),
            tenant_id=tenant_id,
            path=path,
        )

        return await call_next(request)

    def _is_public_route(self, path: str) -> bool:
        """Return True if *path* starts with any entry of ``PUBLIC_ROUTES``."""
        # NOTE(review): this is plain prefix matching, so e.g. "/docsfoo" or
        # "/api/v1/auth/verify-anything" are public too. Consider matching
        # only exact paths or "/"-segment boundaries once it is confirmed that
        # no client relies on the broader behaviour.
        return path.startswith(tuple(PUBLIC_ROUTES))

    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract the JWT from an ``Authorization: Bearer <token>`` header.

        The scheme name is compared case-insensitively (RFC 7235).

        Returns:
            The token string, or None when the header is absent, uses another
            scheme, or carries no credentials.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None

        token = credentials.strip()
        return token or None

    async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify a JWT with a fallback strategy.

        Checks are tried in order:

        1. local signature and claim check via ``jwt_handler``;
        2. the Redis cache of contexts previously confirmed by the auth
           service;
        3. the auth service itself, caching a successful answer.

        Returns:
            The user-context dict, or None if every strategy fails.
        """
        # Try local JWT validation first (fast)
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                return payload
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Check the cache for recently validated tokens.
        # NOTE(review): entries live for the full cache TTL even if the token
        # expires sooner; consider capping the TTL by the token's "exp".
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Verify with the auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """Return True if the JWT payload has the ``user_id``, ``email`` and ``exp`` claims."""
        required_fields = ["user_id", "email", "exp"]
        return all(field in payload for field in required_fields)

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """Ask the auth service to verify *token*.

        POSTs to ``{AUTH_SERVICE_URL}/api/v1/auth/verify`` with the token as a
        Bearer credential and a 3-second timeout.

        Returns:
            The decoded JSON object on HTTP 200, otherwise None. Network and
            decoding errors are logged and also yield None.
        """
        try:
            import httpx

            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )

            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}")
                return None

            data = response.json()
            # Callers index the context like a mapping (e.g. ["user_id"]), so
            # any other JSON type is treated as a failed verification.
            if not isinstance(data, dict):
                logger.warning("Auth service returned a non-object payload")
                return None
            return data

        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _cache_key(token: str) -> str:
        """Build the Redis key under which the context for *token* is cached."""
        import hashlib

        # The builtin hash() is salted per interpreter process (PYTHONHASHSEED),
        # so its keys differ between gateway workers and across restarts. A
        # SHA-256 digest is stable and keeps the raw token out of Redis.
        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        return f"auth:token:{digest}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the cached user context for *token*, or None on a miss or error."""
        if not self.redis_client:
            return None

        cache_key = self._cache_key(token)
        try:
            cached_data = await self.redis_client.get(cache_key)
            if cached_data:
                import json

                # Some Redis clients return bytes unless decode_responses is set.
                if isinstance(cached_data, bytes):
                    cached_data = cached_data.decode()
                return json.loads(cached_data)
        except Exception as e:
            logger.warning(f"Cache get failed: {e}")

        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
        """Cache *user_context* as JSON under the token's key for *ttl* seconds (default 5 minutes)."""
        if not self.redis_client:
            return

        cache_key = self._cache_key(token)
        try:
            import json

            await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")

    def _inject_context_headers(
        self,
        request: Request,
        user_context: Dict[str, Any],
        tenant_id: Optional[str],
        tenant_verified: bool = True,
    ):
        """Replace the trust headers on *request* with gateway-asserted values.

        Any client-supplied copies of ``_CONTEXT_HEADERS`` are dropped first.
        Then ``x-authenticated``, ``x-user-id``, ``x-user-email`` and
        ``x-user-role`` are appended, plus ``x-tenant-id`` when *tenant_id* is
        set. ``x-tenant-verified: true`` is only added when *tenant_verified*
        is true.

        The rewrite is applied to ``request.scope["headers"]``, which is what
        downstream ASGI apps read, and the cached ``request.headers`` view is
        kept pointing at the same list.
        """
        # Accessing request.headers makes Starlette build and cache its Headers
        # view, so the cached view can be re-pointed at the new list below.
        headers = request.headers

        # Remove any incoming trust headers to prevent spoofing. Raw names are
        # bytes, so they are compared against the bytes names in the set.
        raw_headers = [
            (key, value)
            for key, value in headers.raw
            if key.lower() not in self._CONTEXT_HEADERS
        ]

        raw_headers.extend(
            [
                (b"x-authenticated", b"true"),
                (b"x-user-id", str(user_context.get("user_id", "")).encode()),
                (b"x-user-email", str(user_context.get("email", "")).encode()),
                (b"x-user-role", str(user_context.get("role", "user")).encode()),
            ]
        )

        if tenant_id:
            raw_headers.append((b"x-tenant-id", str(tenant_id).encode()))
            # Only vouch for the tenant when an access check actually ran.
            if tenant_verified:
                raw_headers.append((b"x-tenant-verified", b"true"))

        # Downstream apps build their Request/Headers from scope["headers"],
        # so that list must change. The cached view is pointed at the same
        # list so this request object agrees with the scope.
        request.scope["headers"] = raw_headers
        headers._list = raw_headers

        logger.debug(
            "Injected context headers",
            user_id=user_context.get("user_id"),
            tenant_id=tenant_id,
        )
|