Files
bakery-ia/gateway/app/middleware/auth.py
2025-10-23 07:44:54 +02:00

520 lines
22 KiB
Python

# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
FIXED VERSION - Proper JWT verification and token structure handling
"""
import hashlib
import json
import time
from typing import Optional, Dict, Any
from urllib.parse import quote

import httpx
import structlog
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
# Module-wide structlog logger used by the authentication middleware.
logger = structlog.get_logger()
# JWT handler for local token validation - built from the same secret key and
# algorithm settings the auth service uses, so its tokens verify here.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
# Request paths served without authentication. The middleware compares each
# entry as a string prefix (str.startswith), so paths beneath an entry are
# public too.
PUBLIC_ROUTES = [
    # Infrastructure and API documentation
    "/health", "/metrics", "/docs", "/redoc", "/openapi.json",
    # Authentication flow
    "/api/v1/auth/login", "/api/v1/auth/register",
    "/api/v1/auth/refresh", "/api/v1/auth/verify",
    # Public lookups, subscription plans and demo bootstrap
    "/api/v1/nominatim/search", "/api/v1/plans",
    "/api/v1/demo/accounts", "/api/v1/demo/sessions",
]
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control.

    For every request that is neither a CORS preflight (OPTIONS) nor on a
    public route (``PUBLIC_ROUTES``):

    1. Demo sessions: if an upstream middleware already set
       ``request.state.is_demo_session`` and ``request.state.user``, or the SSE
       endpoint receives a ``demo_session_id`` query parameter that the
       demo-session service accepts, the request is forwarded with that demo
       identity at the enterprise subscription tier.
    2. JWT: the token comes from ``Authorization: Bearer <token>`` (or the
       ``token`` query parameter on the SSE endpoint). It is verified locally,
       then via the Redis cache, then via the auth service.
    3. Tenant access: on tenant-scoped paths the user must pass
       ``tenant_access_manager.verify_basic_tenant_access``, and the tenant's
       subscription tier is looked up.
    4. Context propagation: the user context is stored on ``request.state``,
       and ``x-user-*`` / ``x-tenant-id`` / ``x-subscription-tier`` headers are
       injected for downstream services.
    """

    # Identity headers the gateway asserts on the caller's behalf. Client-sent
    # copies are always removed before injection: Starlette's Headers.get()
    # returns the FIRST matching entry, so a forged value sent by the client
    # would otherwise shadow the one appended by the gateway.
    # (x-tenant-id is handled separately: it is replaced whenever the gateway
    # injects one.)
    _GATEWAY_IDENTITY_HEADERS = frozenset({
        b"x-user-id",
        b"x-user-email",
        b"x-user-role",
        b"x-user-type",
        b"x-service-name",
        b"x-subscription-tier",
        b"x-forwarded-by",
    })

    # Upper bound, in seconds, for caching an auth-service verification.
    # The token's own ``exp`` may shorten it further.
    _TOKEN_CACHE_TTL = 300

    # Tokens expiring within this many seconds trigger a warning log and the
    # X-Token-Refresh-Suggested response header.
    _NEAR_EXPIRY_SECONDS = 300

    # SSE endpoint: browsers' EventSource cannot send custom headers, so its
    # credentials arrive as query parameters.
    _SSE_PATH = "/api/events"

    def __init__(self, app, redis_client=None):
        super().__init__(app)
        # Optional async Redis client. It is used to cache token verifications
        # and tenant tiers, and is handed to tenant_access_manager when that
        # manager has no client of its own.
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Authenticate ``request`` and pass it to ``call_next``.

        Returns:
            The downstream response on success. Otherwise a JSONResponse with:
            401 for missing/invalid credentials or an invalid demo session,
            403 when tenant access is denied, or
            503 when the demo-session service cannot be reached.
        """
        path = request.url.path

        # Skip authentication for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Skip authentication for public routes
        if self._is_public_route(path):
            return await call_next(request)

        # Check if demo middleware already set user context OR check query param for SSE
        demo_session_header = request.headers.get("X-Demo-Session-Id")
        demo_session_query = request.query_params.get("demo_session_id")  # For SSE endpoint
        # Logged at debug (not info): this runs on every protected request, and
        # the demo session id in the query acts as a credential.
        logger.debug(
            f"Auth check - path: {path}, demo_header: {demo_session_header}, "
            f"demo_query: {demo_session_query}, "
            f"has_demo_state: {hasattr(request.state, 'is_demo_session')}"
        )

        # For the SSE endpoint with demo_session_id in query params, validate it here
        if (
            path == self._SSE_PATH
            and demo_session_query
            and not hasattr(request.state, "is_demo_session")
        ):
            error_response = await self._authenticate_demo_sse(request, demo_session_query)
            if error_response is not None:
                return error_response

        if getattr(request.state, "is_demo_session", False) and getattr(request.state, "user", None):
            logger.info(f"✅ Demo session authenticated for route: {path}")
            # The demo session was already validated and the user context set,
            # but downstream services still need the context headers.
            user_context = request.state.user
            tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
            # Demo sessions always get the enterprise tier for full feature access
            user_context["subscription_tier"] = "enterprise"
            logger.debug("Demo session subscription tier set to enterprise", tenant_id=tenant_id)
            self._inject_context_headers(request, user_context, tenant_id)
            return await call_next(request)

        # STEP 1: Extract the JWT
        token = self._extract_token(request)
        if not token:
            logger.warning(f"❌ Missing token for protected route: {path}, demo_header: {demo_session_header}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"}
            )

        # STEP 2: Verify the token and build the user context
        user_context = await self._verify_token(token, request)
        if not user_context:
            logger.warning(f"Invalid token for route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "User not authenticated"}
            )

        # STEP 3: Extract tenant context from the URL using the shared utility
        tenant_id = extract_tenant_id_from_path(path)

        # STEP 4: Verify tenant access if this is a tenant-scoped route
        if tenant_id and is_tenant_scoped_path(path):
            # Share our Redis client with the access manager for its caching
            if self.redis_client and tenant_access_manager.redis_client is None:
                tenant_access_manager.redis_client = self.redis_client
            has_access = await tenant_access_manager.verify_basic_tenant_access(
                user_context["user_id"],
                tenant_id
            )
            if not has_access:
                logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": f"Access denied to tenant {tenant_id}"}
                )

            subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
            if subscription_tier:
                user_context["subscription_tier"] = subscription_tier

            request.state.tenant_id = tenant_id
            request.state.tenant_verified = True
            logger.debug("Tenant access verified",
                         user_id=user_context["user_id"],
                         tenant_id=tenant_id,
                         subscription_tier=subscription_tier,
                         path=path)

        # STEP 5: Expose the user context to route handlers
        request.state.user = user_context
        request.state.authenticated = True

        # STEP 6: Add context headers for downstream services
        self._inject_context_headers(request, user_context, tenant_id)
        logger.debug("Authenticated request",
                     user_email=user_context['email'],
                     tenant_id=tenant_id,
                     path=path)

        response = await call_next(request)
        # The flag is set by _verify_token when a locally validated token is close to expiry
        if getattr(request.state, "token_near_expiry", False):
            response.headers["X-Token-Refresh-Suggested"] = "true"
        return response

    async def _authenticate_demo_sse(self, request: Request, demo_session_id: str) -> Optional[JSONResponse]:
        """
        Validate a demo session id passed as a query parameter to the SSE endpoint.

        On success, sets ``request.state.is_demo_session``, ``request.state.user``
        and ``request.state.tenant_id``, and returns None.

        Returns:
            None on success; a 401 JSONResponse when the demo-session service
            does not answer 200; a 503 JSONResponse on any error while
            contacting it or parsing its reply.
        """
        logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_id}")
        # Read from settings when defined; the fallbacks are the values this
        # gateway previously hard-coded.
        # NOTE(review): confirm these settings attribute names, and make sure
        # the dev key fallback is never used in production.
        base_url = getattr(settings, "DEMO_SESSION_SERVICE_URL", "http://demo-session-service:8000")
        internal_api_key = getattr(settings, "INTERNAL_API_KEY", "dev-internal-key-change-in-production")
        # The id is client-controlled: percent-encode it so it cannot inject
        # extra path segments into the internal request.
        url = f"{str(base_url).rstrip('/')}/api/v1/demo/sessions/{quote(demo_session_id, safe='')}"
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers={"X-Internal-API-Key": internal_api_key})
                if response.status_code != 200:
                    logger.warning(f"Invalid demo session for SSE: {demo_session_id}")
                    return JSONResponse(
                        status_code=401,
                        content={"detail": "Invalid demo session"}
                    )
                session_data = response.json()
                virtual_tenant_id = session_data.get("virtual_tenant_id")
                request.state.is_demo_session = True
                request.state.user = {
                    "user_id": f"demo-user-{demo_session_id}",
                    "email": f"demo-{demo_session_id}@demo.local",
                    "tenant_id": virtual_tenant_id,
                    "demo_session_id": demo_session_id,
                }
                request.state.tenant_id = virtual_tenant_id
                logger.info(f"✅ Demo session validated for SSE: {demo_session_id}")
                return None
        except Exception as e:
            logger.error(f"Failed to validate demo session for SSE: {e}")
            return JSONResponse(
                status_code=503,
                content={"detail": "Demo session service unavailable"}
            )

    def _is_public_route(self, path: str) -> bool:
        """Return True when ``path`` starts with any entry of PUBLIC_ROUTES."""
        return any(path.startswith(route) for route in PUBLIC_ROUTES)

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract the JWT from the Authorization header, or from the query string on SSE.

        On the SSE endpoint the token must come from the ``token`` query
        parameter, because the EventSource API cannot send custom headers.
        Everywhere else only an ``Authorization: Bearer <token>`` header is
        accepted. The scheme is matched case-insensitively, per RFC 7235.

        Security note: query-param tokens end up in access logs. Use a short
        expiry and filter logs.

        Returns:
            The token string, or None when absent/empty.
        """
        if request.url.path == self._SSE_PATH:
            token = request.query_params.get("token")
            if token:
                logger.debug("Token extracted from query param for SSE endpoint")
                return token
            logger.warning("SSE request missing token in query param")
            return None

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    async def _verify_token(self, token: str, request: Optional[Request] = None) -> Optional[Dict[str, Any]]:
        """
        Verify a JWT and return the user context, or None when it is invalid.

        Strategies, in order:
        1. Local signature/payload validation with ``jwt_handler``. When
           ``request`` is given, this also flags near-expiry tokens on
           ``request.state.token_near_expiry``.
        2. The Redis cache of earlier auth-service verifications.
        3. The auth service, which is authoritative. Successful results are cached.
        """
        # Strategy 1: local JWT validation (no network)
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                if request is not None:
                    time_until_expiry = payload.get("exp", 0) - time.time()
                    if time_until_expiry < self._NEAR_EXPIRY_SECONDS:
                        request.state.token_near_expiry = True
                return self._jwt_payload_to_user_context(payload)
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Strategy 2: recently validated tokens
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Strategy 3: auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Check that a decoded JWT payload has the fields this gateway relies on.

        Requires ``user_id``, ``email``, ``exp`` and ``type``, and requires
        ``type`` to be ``"access"`` or ``"service"``. Logs a warning (without
        rejecting) when the token expires within the near-expiry window.
        """
        required_fields = ["user_id", "email", "exp", "type"]
        missing_fields = [field for field in required_fields if field not in payload]
        if missing_fields:
            logger.warning(f"Token payload missing fields: {missing_fields}")
            return False

        token_type = payload.get("type")
        if token_type not in ["access", "service"]:
            logger.warning(f"Invalid token type: {payload.get('type')}")
            return False

        time_until_expiry = payload.get("exp", 0) - time.time()
        if time_until_expiry < self._NEAR_EXPIRY_SECONDS:
            logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")

        return True

    def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Map a validated JWT payload to the gateway's user-context dict.

        Tokens carrying a ``service`` claim are treated as service identities.
        They get ``type="service"``, ``role="admin"``, and a synthetic
        ``<name>-service`` user id and email.
        """
        base_context = {
            "user_id": payload["user_id"],
            "email": payload["email"],
            "exp": payload["exp"],
            "valid": True,
            "role": payload.get("role", "user"),
        }

        if payload.get("service"):
            service_name = payload["service"]
            base_context["service"] = service_name
            base_context["type"] = "service"
            base_context["role"] = "admin"
            base_context["user_id"] = f"{service_name}-service"
            base_context["email"] = f"{service_name}-service@internal"
            logger.debug(f"Service authentication: {payload['service']}")

        return base_context

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Ask the auth service to verify ``token`` (POST /api/v1/auth/verify, 5 s timeout).

        Returns:
            A user context with ``user_id``, ``email``, ``exp`` and ``valid``
            when the service answers 200 with ``valid`` and ``user_id`` set;
            otherwise None. Errors are logged, not raised.
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"}
                )

            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}: {response.text}")
                return None

            auth_response = response.json()
            if auth_response.get("valid") and auth_response.get("user_id"):
                return {
                    "user_id": auth_response["user_id"],
                    "email": auth_response["email"],
                    "exp": auth_response.get("exp"),
                    "valid": True
                }
            logger.warning(f"Auth service returned invalid response: {auth_response}")
            return None

        except httpx.TimeoutException:
            logger.error("Auth service timeout during token verification")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Build the Redis key for ``token``.

        The key is a SHA-256 digest of the token. It is stable across processes
        and replicas, cannot realistically collide between distinct tokens
        (a collision would let one token resolve to another user's cached
        identity), and never stores the raw token.
        """
        return f"auth:token:{hashlib.sha256(token.encode()).hexdigest()}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Return the cached user context for ``token``, or None.

        Returns None on a cache miss, on unparsable or non-dict data, or when
        the cached ``exp`` (if numeric) is already in the past.
        """
        if not self.redis_client:
            return None

        try:
            cached_data = await self.redis_client.get(self._token_cache_key(token))
            if not cached_data:
                return None
            if isinstance(cached_data, bytes):
                cached_data = cached_data.decode()
            user_context = json.loads(cached_data)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse cached user data: {e}")
            return None
        except Exception as e:
            logger.warning(f"Cache lookup error: {e}")
            return None

        if not isinstance(user_context, dict):
            return None
        # Never serve a cached identity for a token that has since expired
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)) and exp <= time.time():
            return None
        return user_context

    async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
        """
        Cache ``user_context`` for ``token``.

        The TTL is ``_TOKEN_CACHE_TTL`` seconds, shortened to the token's
        remaining lifetime when ``exp`` is numeric. Nothing is cached for an
        already-expired token. Failures are logged, not raised.
        """
        if not self.redis_client:
            return

        ttl = self._TOKEN_CACHE_TTL
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)):
            remaining = int(exp - time.time())
            if remaining <= 0:
                return
            ttl = min(ttl, remaining)

        try:
            await self.redis_client.setex(
                self._token_cache_key(token),
                ttl,
                json.dumps(user_context)
            )
        except Exception as e:
            logger.warning(f"Failed to cache user context: {e}")

    def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
        """
        Inject user and tenant context headers for downstream services.

        The request's raw header list (Starlette ``Headers.raw``, which is
        the ASGI ``scope["headers"]`` list) is mutated in place so that
        ``call_next`` forwards the change. Client-supplied copies of the
        gateway identity headers are removed first, as is ``x-tenant-id``
        when the gateway injects one, so downstream services only ever see
        gateway-asserted values.

        Headers set: x-user-id, x-user-email, x-user-role, plus x-user-type,
        x-service-name, x-tenant-id and x-subscription-tier when present,
        and x-forwarded-by.
        """
        user_type = user_context.get("type", "")
        service_name = user_context.get("service", "")
        user_role = user_context.get("role") or "user"
        subscription_tier = user_context.get("subscription_tier", "")

        logger.debug(
            "Injecting context headers",
            user_id=user_context.get("user_id"),
            user_type=user_type,
            service_name=service_name,
            role=user_context.get("role", ""),
            tenant_id=tenant_id,
            path=request.url.path
        )

        injected = [
            (b"x-user-id", str(user_context["user_id"]).encode()),
            (b"x-user-email", str(user_context["email"]).encode()),
            (b"x-user-role", str(user_role).encode()),
        ]
        if user_type:
            injected.append((b"x-user-type", str(user_type).encode()))
        if service_name:
            injected.append((b"x-service-name", str(service_name).encode()))
        if tenant_id:
            injected.append((b"x-tenant-id", str(tenant_id).encode()))
        if subscription_tier:
            injected.append((b"x-subscription-tier", str(subscription_tier).encode()))
        injected.append((b"x-forwarded-by", b"bakery-gateway"))

        to_strip = set(self._GATEWAY_IDENTITY_HEADERS)
        if tenant_id:
            to_strip.add(b"x-tenant-id")

        raw_headers = request.headers.raw
        # Slice-assign to keep the same list object that the ASGI scope references
        raw_headers[:] = [(name, value) for name, value in raw_headers if name.lower() not in to_strip]
        raw_headers.extend(injected)

    async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
        """
        Get the tenant's subscription tier, using Redis (5 min) or the tenant service.

        The caller's Authorization header is forwarded to the tenant service.

        Args:
            tenant_id: Tenant ID
            request: Incoming request, used for its Authorization header

        Returns:
            The tier string. "basic" when the tenant service omits it, answers
            non-200, or an error occurs.
        """
        cache_key = f"tenant:tier:{tenant_id}"
        try:
            if self.redis_client:
                try:
                    cached_tier = await self.redis_client.get(cache_key)
                    if cached_tier:
                        if isinstance(cached_tier, bytes):
                            cached_tier = cached_tier.decode()
                        logger.debug("Subscription tier from cache", tenant_id=tenant_id, tier=cached_tier)
                        return cached_tier
                except Exception as e:
                    logger.warning(f"Cache lookup failed for tenant tier: {e}")

            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}",
                    headers={"Authorization": request.headers.get("Authorization", "")}
                )

                if response.status_code != 200:
                    logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
                    return "basic"  # Default to basic

                subscription_tier = response.json().get("subscription_tier", "basic")

                if self.redis_client:
                    try:
                        await self.redis_client.setex(cache_key, 300, subscription_tier)  # 5 minutes
                    except Exception as e:
                        logger.warning(f"Failed to cache tenant tier: {e}")

                logger.debug("Subscription tier from service", tenant_id=tenant_id, tier=subscription_tier)
                return subscription_tier

        except Exception as e:
            logger.error(f"Error getting tenant subscription tier: {e}")
            return "basic"  # Default to basic on error