463 lines
18 KiB
Python
463 lines
18 KiB
Python
# gateway/app/middleware/auth.py
|
|
"""
|
|
Enhanced authentication middleware for the API gateway, with tenant access control.

Revised version: proper JWT verification and handling of the token structure
issued by the auth service.
|
|
"""
|
|
|
|
import hashlib
import json
import time
from typing import Any, Dict, Optional

import httpx
import structlog
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import (
    extract_tenant_id_from_path,
    is_tenant_scoped_path,
    tenant_access_manager,
)
|
|
|
|
logger = structlog.get_logger()

# Verifier used for local (network-free) JWT validation. It is configured with
# the same secret key and algorithm as the auth service, so tokens issued there
# can be checked here.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Routes that don't require authentication. AuthMiddleware._is_public_route
# decides how an incoming request path is matched against these entries.
PUBLIC_ROUTES = [
    # Operational and API-documentation endpoints
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    # Authentication endpoints
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    # Other public endpoints
    "/api/v1/nominatim/search",
    "/api/v1/plans",
    # Demo endpoints
    "/api/v1/demo/accounts",
    "/api/v1/demo/sessions",
]
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control.

    For every request that is not a CORS preflight, not a public route and not
    an already-authenticated demo session, this middleware:

    1. extracts the bearer token (``Authorization`` header, or the ``token``
       query parameter on the ``/api/events`` SSE endpoint);
    2. verifies it locally, then against the Redis cache, then with the auth
       service;
    3. on tenant-scoped paths, checks tenant access and resolves the tenant's
       subscription tier;
    4. stores the user context on ``request.state`` and injects ``x-*`` context
       headers for downstream services.
    """

    # Context headers owned by the gateway. Any client-supplied copies are
    # removed before injection so downstream services never see spoofed values.
    _CONTEXT_HEADERS = frozenset({
        b"x-user-id",
        b"x-user-email",
        b"x-user-role",
        b"x-user-type",
        b"x-service-name",
        b"x-tenant-id",
        b"x-subscription-tier",
        b"x-forwarded-by",
    })

    # Maximum lifetime (seconds) of cached token validations and tenant tiers.
    _CACHE_TTL_SECONDS = 300

    # Tokens with less than this many seconds left get a refresh hint / warning.
    _NEAR_EXPIRY_SECONDS = 300

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The downstream ASGI application.
            redis_client: Optional async Redis client used to cache token
                validations and tenant subscription tiers.
        """
        super().__init__(app)
        self.redis_client = redis_client  # For caching and rate limiting

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Authenticate the request, enforce tenant access, then forward it.

        Returns:
            A 401 JSON response if the token is missing or invalid, a 403 JSON
            response if the user may not access the tenant named in the path,
            otherwise the downstream response (tagged with
            ``X-Token-Refresh-Suggested: true`` when the token is near expiry).
        """
        path = request.url.path

        # Skip authentication for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Skip authentication for public routes
        if self._is_public_route(path):
            return await call_next(request)

        # Check whether the demo middleware already set the user context
        demo_session_header = request.headers.get("X-Demo-Session-Id")
        logger.info(f"Auth check - path: {path}, demo_header: {demo_session_header}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")

        if getattr(request.state, "is_demo_session", False) and getattr(request.state, "user", None):
            logger.info(f"✅ Demo session authenticated for route: {path}")
            # Demo middleware already validated and set user context, pass through
            return await call_next(request)

        # STEP 1: Extract the JWT
        token = self._extract_token(request)
        if not token:
            logger.warning(f"❌ Missing token for protected route: {path}, demo_header: {demo_session_header}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # STEP 2: Verify the token and build the user context
        user_context = await self._verify_token(token, request)
        if not user_context:
            logger.warning(f"Invalid token for route: {path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "User not authenticated"},
            )

        # STEP 3: Extract the tenant id from the URL using the shared utility
        tenant_id = extract_tenant_id_from_path(path)

        # STEP 4: Verify tenant access if this is a tenant-scoped route
        if tenant_id and is_tenant_scoped_path(path):
            # Hand our Redis client to the shared TenantAccessManager so its
            # gateway-level checks can use caching.
            if self.redis_client and tenant_access_manager.redis_client is None:
                tenant_access_manager.redis_client = self.redis_client

            has_access = await tenant_access_manager.verify_basic_tenant_access(
                user_context["user_id"],
                tenant_id,
            )
            if not has_access:
                logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": f"Access denied to tenant {tenant_id}"},
                )

            # Look up the tenant's subscription tier and add it to the user context
            subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
            if subscription_tier:
                user_context["subscription_tier"] = subscription_tier

            request.state.tenant_id = tenant_id
            request.state.tenant_verified = True

            logger.debug(
                "Tenant access verified",
                user_id=user_context["user_id"],
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                path=path,
            )

        # STEP 5: Expose the user context to route handlers
        request.state.user = user_context
        request.state.authenticated = True

        # STEP 6: Add context headers for downstream services.
        # NOTE(review): tenant_id is injected even when is_tenant_scoped_path()
        # is False, i.e. without the access check above -- confirm downstream
        # services don't trust x-tenant-id on such paths.
        self._inject_context_headers(request, user_context, tenant_id)

        logger.debug(
            "Authenticated request",
            user_email=user_context["email"],
            tenant_id=tenant_id,
            path=path,
        )

        response = await call_next(request)

        # Flag set by _verify_token when a locally validated token is close to expiry
        if getattr(request.state, "token_near_expiry", False):
            response.headers["X-Token-Refresh-Suggested"] = "true"

        return response

    def _is_public_route(self, path: str) -> bool:
        """
        Return True if ``path`` does not require authentication.

        A path is public when it equals an entry of ``PUBLIC_ROUTES`` or lies
        beneath one on a path-segment boundary (``/docs/oauth2-redirect`` under
        ``/docs``). A bare string-prefix match is deliberately not enough:
        ``/api/v1/plans-admin`` must not inherit the public status of
        ``/api/v1/plans``.
        """
        for route in PUBLIC_ROUTES:
            base = route.rstrip("/")
            if path == route or path.startswith(base + "/"):
                return True
        return False

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract the JWT from the Authorization header, or from the query string for SSE.

        For the SSE endpoint (/api/events) browsers' EventSource API cannot send
        custom headers, so the token is accepted as the ``token`` query parameter.
        Every other route must use an ``Authorization: Bearer <token>`` header
        (more secure). The scheme name is matched case-insensitively.

        Security note: query-param tokens end up in access logs. Use short expiry
        and filter logs.

        Returns:
            The token string, or None if no usable token is present.
        """
        # SSE endpoint exception: token in query param (EventSource API limitation)
        if request.url.path == "/api/events":
            token = request.query_params.get("token")
            if token:
                logger.debug("Token extracted from query param for SSE endpoint")
                return token
            logger.warning("SSE request missing token in query param")
            return None

        # Standard authentication: Authorization header for all other routes
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None

    async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
        """
        Verify a JWT and return the corresponding user context.

        Strategies, in order:
          1. local verification with the shared JWT handler (fast);
          2. a previously cached auth-service validation (Redis, if configured);
          3. the auth service itself (authoritative); successes are cached.

        Args:
            token: Raw bearer token.
            request: If given, ``request.state.token_near_expiry`` is set when a
                locally validated token has less than ``_NEAR_EXPIRY_SECONDS`` left.

        Returns:
            The user-context dict, or None if every strategy rejects the token.
        """
        # Strategy 1: local JWT validation
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")
                if request is not None and payload["exp"] - time.time() < self._NEAR_EXPIRY_SECONDS:
                    request.state.token_near_expiry = True
                return self._jwt_payload_to_user_context(payload)
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Strategy 2: recently validated tokens in the cache
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Strategy 3: ask the auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Check that a decoded JWT payload is usable as an access or service token.

        Requires ``user_id``, ``email``, ``exp`` and ``type`` fields, a ``type``
        of "access" or "service", and an ``exp`` still in the future (defense in
        depth in case the JWT handler does not enforce expiry). Tokens close to
        expiry are accepted but logged.

        Raises:
            TypeError: if ``exp`` is not numeric (callers treat this as a failed
                local validation).
        """
        required_fields = ["user_id", "email", "exp", "type"]
        missing_fields = [field for field in required_fields if field not in payload]
        if missing_fields:
            logger.warning(f"Token payload missing fields: {missing_fields}")
            return False

        token_type = payload.get("type")
        if token_type not in ("access", "service"):
            logger.warning(f"Invalid token type: {token_type}")
            return False

        time_until_expiry = payload["exp"] - time.time()
        if time_until_expiry <= 0:
            logger.warning(f"Token expired {int(-time_until_expiry)} seconds ago for user {payload.get('email')}")
            return False

        if time_until_expiry < self._NEAR_EXPIRY_SECONDS:
            logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")

        return True

    def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Map a validated JWT payload to the gateway's user-context dict.

        Payloads carrying a ``service`` claim are service tokens: they get
        ``type="service"``, ``role="admin"`` and synthetic
        ``<service>-service`` / ``<service>-service@internal`` identities.
        """
        base_context = {
            "user_id": payload["user_id"],
            "email": payload["email"],
            "exp": payload["exp"],
            "valid": True,
            "role": payload.get("role", "user"),
        }

        service_name = payload.get("service")
        if service_name:
            base_context["service"] = service_name
            base_context["type"] = "service"
            base_context["role"] = "admin"
            base_context["user_id"] = f"{service_name}-service"
            base_context["email"] = f"{service_name}-service@internal"
            logger.debug(f"Service authentication: {service_name}")

        return base_context

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Ask the auth service (``POST /api/v1/auth/verify``) to validate ``token``.

        Returns:
            ``{"user_id", "email", "exp", "valid": True}`` when the service accepts
            the token, otherwise None (non-200 status, body without ``valid``,
            ``user_id`` or ``email``, timeout, or any other error).
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"},
                )

                if response.status_code != 200:
                    logger.warning(f"Auth service returned {response.status_code}: {response.text}")
                    return None

                auth_response = response.json()

                # Validate the auth service response structure
                if auth_response.get("valid") and auth_response.get("user_id") and auth_response.get("email"):
                    return {
                        "user_id": auth_response["user_id"],
                        "email": auth_response["email"],
                        "exp": auth_response.get("exp"),
                        "valid": True,
                    }

                logger.warning(f"Auth service returned invalid response: {auth_response}")
                return None

        except httpx.TimeoutException:
            logger.error("Auth service timeout during token verification")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Return the Redis key under which a token's validation is cached.

        SHA-256 of the full token is used because it is:
        - collision-resistant: a truncated key such as ``hash(token) % 10**6``
          would let an unrelated (even forged) token read another user's entry;
        - stable across processes: ``hash()`` on str is salted per interpreter;
        - non-reversible: the raw token never lands in Redis.
        """
        return f"auth:token:{hashlib.sha256(token.encode()).hexdigest()}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Return the cached user context for ``token``, or None.

        Cache-read and JSON-decode errors are logged and treated as a miss.
        """
        if not self.redis_client:
            return None

        try:
            cached_data = await self.redis_client.get(self._token_cache_key(token))
            if cached_data:
                if isinstance(cached_data, bytes):
                    cached_data = cached_data.decode()
                return json.loads(cached_data)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse cached user data: {e}")
        except Exception as e:
            logger.warning(f"Cache lookup error: {e}")

        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
        """
        Cache a validated user context for ``token`` (best effort).

        The TTL is ``_CACHE_TTL_SECONDS``, shortened to the token's remaining
        lifetime when ``user_context["exp"]`` is a numeric timestamp, so a cached
        entry never outlives its token. Tokens already past ``exp`` are not
        cached. Cache-write errors are logged, not raised.
        """
        if not self.redis_client:
            return

        ttl = self._CACHE_TTL_SECONDS
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)):
            ttl = min(ttl, int(exp - time.time()))
            if ttl <= 0:
                return

        try:
            await self.redis_client.setex(
                self._token_cache_key(token),
                ttl,
                json.dumps(user_context),
            )
        except Exception as e:
            logger.warning(f"Failed to cache user context: {e}")

    def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
        """
        Inject user and tenant context headers for downstream services.

        Any client-supplied copy of a gateway-owned header (``_CONTEXT_HEADERS``)
        is removed first. Starlette's ``Headers.get`` returns the first match, so
        merely appending would let a spoofed ``X-User-Id`` / ``X-User-Role``
        sent by the client win over the verified value.
        """
        # NOTE(review): relies on Starlette internals. request.headers wraps the
        # raw (name, value) list it was built from (shared with the ASGI scope),
        # so the list is mutated in place rather than rebound.
        raw_headers = request.headers.__dict__["_list"]
        raw_headers[:] = [
            (name, value)
            for name, value in raw_headers
            if name.lower() not in self._CONTEXT_HEADERS
        ]

        # str() guards against non-str ids/emails (e.g. from the auth service JSON).
        injected = [
            (b"x-user-id", str(user_context["user_id"]).encode()),
            (b"x-user-email", str(user_context["email"]).encode()),
            (b"x-user-role", str(user_context.get("role") or "user").encode()),
        ]

        user_type = user_context.get("type", "")
        if user_type:
            injected.append((b"x-user-type", str(user_type).encode()))

        service_name = user_context.get("service", "")
        if service_name:
            injected.append((b"x-service-name", str(service_name).encode()))

        if tenant_id:
            injected.append((b"x-tenant-id", str(tenant_id).encode()))

        subscription_tier = user_context.get("subscription_tier", "")
        if subscription_tier:
            injected.append((b"x-subscription-tier", str(subscription_tier).encode()))

        # Gateway identification
        injected.append((b"x-forwarded-by", b"bakery-gateway"))

        raw_headers.extend(injected)

    async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
        """
        Get the tenant's subscription tier, from cache or the tenant service.

        Args:
            tenant_id: Tenant ID
            request: Incoming request; its Authorization header is forwarded to
                the tenant service.

        Returns:
            The subscription tier string. Falls back to "basic" when the tenant
            service reports no tier, answers with a non-200 status, or any error
            occurs.
        """
        cache_key = f"tenant:tier:{tenant_id}"
        try:
            # Check the cache first
            if self.redis_client:
                try:
                    cached_tier = await self.redis_client.get(cache_key)
                    if cached_tier:
                        if isinstance(cached_tier, bytes):
                            cached_tier = cached_tier.decode()
                        logger.debug("Subscription tier from cache", tenant_id=tenant_id, tier=cached_tier)
                        return cached_tier
                except Exception as e:
                    logger.warning(f"Cache lookup failed for tenant tier: {e}")

            # Fall back to the tenant service
            async with httpx.AsyncClient(timeout=5.0) as client:
                headers = {"Authorization": request.headers.get("Authorization", "")}
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}",
                    headers=headers,
                )

                if response.status_code != 200:
                    logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
                    return "basic"  # Default to basic

                tenant_data = response.json()

            # `or` also covers an explicit null tier in the response
            subscription_tier = tenant_data.get("subscription_tier") or "basic"

            if self.redis_client:
                try:
                    await self.redis_client.setex(cache_key, self._CACHE_TTL_SECONDS, subscription_tier)
                except Exception as e:
                    logger.warning(f"Failed to cache tenant tier: {e}")

            logger.debug("Subscription tier from service", tenant_id=tenant_id, tier=subscription_tier)
            return subscription_tier

        except Exception as e:
            logger.error(f"Error getting tenant subscription tier: {e}")
            return "basic"  # Default to basic on error