# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
FIXED VERSION - Proper JWT verification and token structure handling
"""

import hashlib
import json
import time
from typing import Optional, Dict, Any
from urllib.parse import quote

import httpx
import structlog
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path

logger = structlog.get_logger()

# JWT handler for local token validation - using SAME configuration as auth service
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Routes that don't require authentication
PUBLIC_ROUTES = [
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    "/api/v1/nominatim/search",
    "/api/v1/plans",
    "/api/v1/demo/accounts",
    "/api/v1/demo/sessions",
]

# Routes accessible with demo session (no JWT required, just demo session header)
DEMO_ACCESSIBLE_ROUTES = [
    "/api/v1/tenants/",  # All tenant endpoints accessible in demo mode
]

# Client-supplied headers with these (lower-case) prefixes are dropped before
# any routing decision; the gateway re-adds them from verified context only.
UNTRUSTED_HEADER_PREFIXES = (b"x-subscription-", b"x-user-", b"x-tenant-")

# Path of the SSE endpoint, which authenticates through query parameters.
SSE_PATH = "/api/events"

# A token with less than this many seconds left triggers a refresh hint.
TOKEN_EXPIRY_WARNING_SECONDS = 300

# Upper bound (seconds) for caching an auth-service validated user context.
USER_CACHE_TTL_SECONDS = 300


class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control
    """

    def __init__(self, app, redis_client=None):
        """Wrap ``app``; ``redis_client`` (optional) backs caching and freshness checks."""
        super().__init__(app)
        self.redis_client = redis_client  # For caching and rate limiting

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process request with enhanced authentication and tenant access control.

        Flow: pass through CORS preflight, strip client-supplied context
        headers, pass through public routes, honour an already validated
        demo session (or validate one for the SSE endpoint), otherwise verify
        the JWT, check tenant access for tenant-scoped paths, inject context
        headers and forward the request.

        Returns:
            The downstream response, or a JSON 401/403/503 response when
            authentication, authorization or demo-session validation fails.
        """
        # Skip authentication for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        # SECURITY: Remove any incoming x-subscription-*/x-user-*/x-tenant-* headers.
        # These will be re-injected from verified context only.
        self._strip_untrusted_headers(request)

        # Skip authentication for public routes
        if self._is_public_route(request.url.path):
            return await call_next(request)

        # ✅ Check if demo middleware already set user context OR check query param for SSE
        demo_session_header = request.headers.get("X-Demo-Session-Id")
        demo_session_query = request.query_params.get("demo_session_id")  # For SSE endpoint
        logger.info(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")

        # For SSE endpoint with demo_session_id in query params, validate it here
        if (
            request.url.path == SSE_PATH
            and demo_session_query
            and not hasattr(request.state, "is_demo_session")
        ):
            error_response = await self._validate_sse_demo_session(request, demo_session_query)
            if error_response is not None:
                return error_response

        if getattr(request.state, "is_demo_session", False) and getattr(request.state, "user", None):
            logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
            # Demo middleware already validated and set user context
            # But we still need to inject context headers for downstream services
            user_context = request.state.user
            tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)

            # For demo sessions, get the actual subscription tier from the tenant service
            # instead of always defaulting to enterprise
            if not user_context.get("subscription_tier"):
                subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
                if subscription_tier:
                    user_context["subscription_tier"] = subscription_tier
                else:
                    # Fallback to enterprise for demo if no tier is found
                    user_context["subscription_tier"] = "enterprise"
                logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)

            await self._inject_context_headers(request, user_context, tenant_id)
            return await call_next(request)

        # ✅ STEP 1: Extract and validate JWT token
        token = self._extract_token(request)
        if not token:
            logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"}
            )

        # ✅ STEP 2: Verify token and get user context
        try:
            user_context = await self._verify_token(token, request)
        except HTTPException as exc:
            # _verify_token raises only for tokens that must be rejected
            # outright (stale subscription data); surface that reason.
            logger.warning(f"Token rejected for route: {request.url.path}", detail=exc.detail)
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail}
            )

        if not user_context:
            logger.warning(f"Invalid token for route: {request.url.path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "User not authenticated"}
            )

        # ✅ STEP 3: Extract tenant context from URL using shared utility
        tenant_id = extract_tenant_id_from_path(request.url.path)

        # ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
        if tenant_id and is_tenant_scoped_path(request.url.path):
            # Use TenantAccessManager for gateway-level verification with caching
            if self.redis_client and tenant_access_manager.redis_client is None:
                tenant_access_manager.redis_client = self.redis_client

            has_access = await tenant_access_manager.verify_basic_tenant_access(
                user_context["user_id"],
                tenant_id
            )

            if not has_access:
                logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": f"Access denied to tenant {tenant_id}"}
                )

            # Get tenant subscription tier and inject into user context
            # NEW: Use JWT data if available, skip HTTP call
            if user_context.get("subscription_from_jwt"):
                subscription_tier = user_context.get("subscription_tier")
                logger.debug("Using subscription tier from JWT", tier=subscription_tier)
            else:
                # Only for old tokens - remove after full rollout
                subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
                if subscription_tier:
                    user_context["subscription_tier"] = subscription_tier

            # Check hierarchical access to determine access type and permissions
            hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
                user_context["user_id"],
                tenant_id
            )

            # Set tenant context in request state
            request.state.tenant_id = tenant_id
            request.state.tenant_verified = True
            request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
            request.state.can_view_children = hierarchical_access.get("can_view_children", False)

            logger.debug(f"Tenant access verified",
                         user_id=user_context["user_id"],
                         tenant_id=tenant_id,
                         subscription_tier=subscription_tier,
                         access_type=hierarchical_access.get("access_type"),
                         can_view_children=hierarchical_access.get("can_view_children"),
                         path=request.url.path)

        # ✅ STEP 5: Inject user context into request
        request.state.user = user_context
        request.state.authenticated = True

        # ✅ STEP 6: Add context headers for downstream services
        await self._inject_context_headers(request, user_context, tenant_id)

        logger.debug(f"Authenticated request",
                     user_email=user_context['email'],
                     tenant_id=tenant_id,
                     path=request.url.path)

        # Process the request
        response = await call_next(request)

        # Add token expiry warning header if token is near expiry
        if getattr(request.state, "token_near_expiry", False):
            response.headers["X-Token-Refresh-Suggested"] = "true"

        return response

    @staticmethod
    def _mutable_header_list(request: Request) -> list:
        """
        Return the raw header list of ``request`` as one shared, mutable list.

        Starlette builds ``Request.headers`` from ``scope["headers"]`` and any
        ``Request`` created later from the same scope reads the scope again.
        The returned list is therefore installed both on the cached
        ``Headers`` object and in the scope, so in-place edits are visible to
        this middleware and to everything downstream.
        """
        headers = request.headers
        raw_list = headers.__dict__["_list"]
        if not isinstance(raw_list, list):
            raw_list = list(raw_list)
        headers.__dict__["_list"] = raw_list
        request.scope["headers"] = raw_list
        return raw_list

    def _strip_untrusted_headers(self, request: Request) -> None:
        """
        Drop client-supplied headers that the gateway itself is meant to set.

        Header names are compared case-insensitively against
        ``UNTRUSTED_HEADER_PREFIXES``. The list is edited in place so the
        ASGI scope (and thus downstream handlers) sees the sanitized headers.

        NOTE(review): other gateway-injected headers (x-is-demo,
        x-service-name, x-parent-tenant-id, x-can-view-children) are not
        covered by these prefixes - confirm whether they should be stripped.
        """
        raw_list = self._mutable_header_list(request)
        raw_list[:] = [
            (name, value)
            for name, value in raw_list
            if not name.lower().startswith(UNTRUSTED_HEADER_PREFIXES)
        ]

    async def _validate_sse_demo_session(self, request: Request, demo_session_id: str) -> Optional[JSONResponse]:
        """
        Validate a demo session id passed as a query parameter to the SSE endpoint.

        On success the demo user context is stored on ``request.state`` and
        ``None`` is returned. Otherwise a JSON error response is returned:
        401 when the demo-session service does not answer 200, 503 when the
        call (or parsing of its answer) raises.
        """
        logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_id}")

        # NOTE(review): the setting names below are assumptions; when they are
        # absent the original hard-coded values are used unchanged.
        base_url = getattr(settings, "DEMO_SESSION_SERVICE_URL", "http://demo-session-service:8000")
        internal_api_key = getattr(settings, "INTERNAL_API_KEY", "dev-internal-key-change-in-production")

        # The id comes straight from the query string: percent-encode it so it
        # cannot alter the path of the internal request (e.g. "../").
        session_url = f"{base_url}/api/v1/demo/sessions/{quote(demo_session_id, safe='')}"

        # Validate demo session via demo-session service
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    session_url,
                    headers={"X-Internal-API-Key": internal_api_key}
                )

            if response.status_code != 200:
                logger.warning(f"Invalid demo session for SSE: {demo_session_id}")
                return JSONResponse(
                    status_code=401,
                    content={"detail": "Invalid demo session"}
                )

            session_data = response.json()
            virtual_tenant_id = session_data.get("virtual_tenant_id")

            # Set demo session context
            request.state.is_demo_session = True
            request.state.user = {
                "user_id": f"demo-user-{demo_session_id}",
                "email": f"demo-{demo_session_id}@demo.local",
                "tenant_id": virtual_tenant_id,
                "demo_session_id": demo_session_id,
            }
            request.state.tenant_id = virtual_tenant_id
            logger.info(f"✅ Demo session validated for SSE: {demo_session_id}")
            return None
        except Exception as e:
            logger.error(f"Failed to validate demo session for SSE: {e}")
            return JSONResponse(
                status_code=503,
                content={"detail": "Demo session service unavailable"}
            )

    def _is_public_route(self, path: str) -> bool:
        """
        Return True when ``path`` is public, i.e. does NOT require authentication.

        A path is public when it starts with any entry of ``PUBLIC_ROUTES``.

        NOTE(review): this is a plain string-prefix match, so "/health" also
        matches "/healthz" and "/api/v1/plans" matches "/api/v1/plans-x".
        Confirm no protected route shares a prefix with a public one.
        """
        return any(path.startswith(route) for route in PUBLIC_ROUTES)

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract JWT token from Authorization header or query params for SSE.

        For SSE endpoints (/api/events), browsers' EventSource API cannot send
        custom headers, so we must accept token as query parameter.
        For all other routes, token must be in Authorization header (more secure).

        Security note: Query param tokens are logged. Use short expiry and filter logs.

        Returns:
            The token string, or None when no usable token is present.
        """
        # SSE endpoint exception: token in query param (EventSource API limitation)
        if request.url.path == SSE_PATH:
            token = request.query_params.get("token")
            if token:
                logger.debug("Token extracted from query param for SSE endpoint")
                return token
            logger.warning("SSE request missing token in query param")
            return None

        # Standard authentication: Authorization header for all other routes
        bearer_prefix = "Bearer "
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith(bearer_prefix):
            # Slice instead of split(" ")[1] so extra whitespace after the
            # scheme does not yield an empty token.
            token = auth_header[len(bearer_prefix):].strip()
            return token or None

        return None

    async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
        """
        Verify JWT token with improved fallback strategy
        FIXED: Better error handling and token structure validation

        Strategies, in order: local JWT validation, Redis cache of recently
        validated tokens, auth service.

        Args:
            token: The raw JWT.
            request: Current request; when given, enables the subscription
                freshness check and the near-expiry flag on ``request.state``.

        Returns:
            The user context dict, or None when no strategy accepts the token.

        Raises:
            HTTPException: 401 when a locally valid token was issued before
                the tenant's last recorded subscription change. Such a token
                is rejected outright and not offered to the other strategies.
        """
        token_is_stale = False
        payload = None

        # Strategy 1: Try local JWT validation first (fast)
        try:
            payload = jwt_handler.verify_token(token)

            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")

                # NEW: Check token freshness for subscription changes (async).
                # _verify_token_freshness itself fails open on Redis errors.
                if payload.get("tenant_id") and request is not None:
                    is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
                    token_is_stale = not is_fresh

                if not token_is_stale:
                    # Check if token is near expiry and set flag for response header
                    if request is not None:
                        time_until_expiry = payload.get("exp", 0) - time.time()
                        if time_until_expiry < TOKEN_EXPIRY_WARNING_SECONDS:
                            request.state.token_near_expiry = True

                    # Convert JWT payload to user context format
                    return self._jwt_payload_to_user_context(payload)

        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Raised outside the try block above so the broad handler cannot
        # swallow the rejection.
        if token_is_stale:
            logger.warning("Stale token detected - subscription changed since token was issued",
                           user_id=payload.get("user_id"),
                           tenant_id=payload.get("tenant_id"))
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token is stale - subscription has changed"
            )

        # Strategy 2: Check cache for recently validated tokens
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Strategy 3: Verify with auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                # Cache successful validation
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Validate JWT payload has required fields
        FIXED: Updated to match actual token structure from auth service

        Returns:
            False when a required field is missing or the token type is not
            "access"/"service"; True otherwise. A warning is logged when the
            token expires within ``TOKEN_EXPIRY_WARNING_SECONDS``.
        """
        required_fields = ["user_id", "email", "exp", "type"]
        missing_fields = [field for field in required_fields if field not in payload]

        if missing_fields:
            logger.warning(f"Token payload missing fields: {missing_fields}")
            return False

        # Validate token type
        token_type = payload.get("type")
        if token_type not in ["access", "service"]:
            logger.warning(f"Invalid token type: {payload.get('type')}")
            return False

        # Check if token is near expiry (within 5 minutes) and log warning
        time_until_expiry = payload.get("exp", 0) - time.time()
        if time_until_expiry < TOKEN_EXPIRY_WARNING_SECONDS:
            logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")

        # Subscription freshness cannot be awaited from this sync method; it
        # is checked by _verify_token.
        return True

    def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
        """
        Validate JWT payload integrity beyond signature verification.
        Prevents edge cases where payload might be malformed.

        Checks: required fields present, issuer equals "bakery-auth", type is
        "access" or "service", and a subscription tier (if any) is one of the
        known tiers.
        """
        # Required fields must exist
        required_fields = ["user_id", "email", "exp", "iat", "iss"]
        missing = [f for f in required_fields if f not in payload]
        if missing:
            logger.warning("JWT missing required fields", missing=missing)
            return False

        # Issuer must be our auth service
        if payload.get("iss") != "bakery-auth":
            logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
            return False

        # Token type must be valid
        if payload.get("type") not in ["access", "service"]:
            logger.warning("JWT has invalid type", token_type=payload.get("type"))
            return False

        # Subscription tier must be valid if present
        valid_tiers = ["starter", "professional", "enterprise"]
        subscription = payload.get("subscription")
        if subscription:
            if not isinstance(subscription, dict):
                # A malformed subscription claim would otherwise crash on .get
                logger.warning("JWT has malformed subscription claim")
                return False
            tier = str(subscription.get("tier") or "").lower()
            if tier and tier not in valid_tiers:
                logger.warning("JWT has invalid subscription tier", tier=tier)
                return False

        return True

    async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
        """
        Verify token was issued after the last subscription change.
        Prevents use of stale tokens with old subscription data.

        Reads ``tenant:{tenant_id}:subscription_changed_at`` from Redis and
        compares it with the token's ``iat``.

        Returns:
            False only when the stored timestamp is later than ``iat``.
            True when Redis is not configured, the key is absent, or the
            lookup fails (fail open for availability).
        """
        if not self.redis_client:
            return True  # Skip check if no Redis

        try:
            subscription_changed_at = await self.redis_client.get(
                f"tenant:{tenant_id}:subscription_changed_at"
            )

            if subscription_changed_at:
                changed_timestamp = float(subscription_changed_at)
                token_issued_at = payload.get("iat", 0)

                if token_issued_at < changed_timestamp:
                    logger.warning(
                        "Token issued before subscription change",
                        token_iat=token_issued_at,
                        subscription_changed=changed_timestamp,
                        tenant_id=tenant_id
                    )
                    return False  # Token is stale
        except Exception as e:
            logger.warning("Failed to check token freshness", error=str(e))

        return True

    def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert JWT payload to user context format
        FIXED: Proper mapping between JWT structure and user context

        Raises:
            HTTPException: 401 when ``_validate_jwt_integrity`` rejects the payload.
        """
        # NEW: Validate JWT integrity before processing
        if not self._validate_jwt_integrity(payload):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid JWT payload"
            )

        base_context = {
            "user_id": payload["user_id"],
            "email": payload["email"],
            "exp": payload["exp"],
            "valid": True,
            "role": payload.get("role", "user"),
        }

        # NEW: Extract subscription from JWT
        if payload.get("tenant_id"):
            base_context["tenant_id"] = payload["tenant_id"]
            base_context["tenant_role"] = payload.get("tenant_role", "member")

        if payload.get("subscription"):
            sub = payload["subscription"]
            base_context["subscription_tier"] = sub.get("tier", "starter")
            base_context["subscription_status"] = sub.get("status", "active")
            base_context["subscription_from_jwt"] = True  # Flag to skip HTTP

        if payload.get("tenant_access"):
            base_context["tenant_access"] = payload["tenant_access"]

        # Service tokens get a synthetic identity and the admin role.
        if payload.get("service"):
            service_name = payload["service"]
            base_context["service"] = service_name
            base_context["type"] = "service"
            base_context["role"] = "admin"
            base_context["user_id"] = f"{service_name}-service"
            base_context["email"] = f"{service_name}-service@internal"
            logger.debug(f"Service authentication: {payload['service']}")

        return base_context

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify token with auth service
        FIXED: Improved error handling and response parsing

        Returns:
            A minimal user context (user_id, email, exp, valid) when the auth
            service answers 200 with ``valid`` and ``user_id`` set; None on
            any other answer, timeout or error.
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"}
                )

            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}: {response.text}")
                return None

            auth_response = response.json()

            # Validate auth service response structure
            if auth_response.get("valid") and auth_response.get("user_id"):
                return {
                    "user_id": auth_response["user_id"],
                    "email": auth_response["email"],
                    "exp": auth_response.get("exp"),
                    "valid": True
                }

            logger.warning(f"Auth service returned invalid response: {auth_response}")
            return None

        except httpx.TimeoutException:
            logger.error("Auth service timeout during token verification")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Build the Redis key under which a token's user context is cached.

        Uses a SHA-256 digest: the built-in ``hash()`` is salted per process
        for strings and, reduced modulo 1,000,000, collides easily, which
        could return another user's cached context for a given token.
        """
        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        return f"auth:token:{digest}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Get user context from cache
        FIXED: Better error handling and JSON parsing

        Returns:
            The cached context dict, or None when Redis is not configured,
            the key is absent, or the entry cannot be read or parsed.
        """
        if not self.redis_client:
            return None

        cache_key = self._token_cache_key(token)

        try:
            cached_data = await self.redis_client.get(cache_key)
            if cached_data:
                if isinstance(cached_data, bytes):
                    cached_data = cached_data.decode()
                return json.loads(cached_data)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse cached user data: {e}")
        except Exception as e:
            logger.warning(f"Cache lookup error: {e}")

        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
        """
        Cache user context
        FIXED: Better error handling and expiration

        The entry lives for at most ``USER_CACHE_TTL_SECONDS`` and, when the
        context carries a numeric ``exp``, never longer than the token's
        remaining lifetime, so an expired token cannot be served from cache.
        Failures are logged and otherwise ignored (best effort).
        """
        if not self.redis_client:
            return

        ttl_seconds = USER_CACHE_TTL_SECONDS
        expires_at = user_context.get("exp")
        if isinstance(expires_at, (int, float)) and not isinstance(expires_at, bool):
            ttl_seconds = min(ttl_seconds, int(expires_at - time.time()))
            if ttl_seconds <= 0:
                return  # Token already expired: nothing worth caching

        cache_key = self._token_cache_key(token)

        try:
            await self.redis_client.setex(
                cache_key,
                ttl_seconds,
                json.dumps(user_context)
            )
        except Exception as e:
            logger.warning(f"Failed to cache user context: {e}")

    async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
        """
        Inject user and tenant context headers for downstream services
        ENHANCED: Added logging to verify header injection

        Headers are appended to the request's raw header list, which is kept
        shared with the ASGI scope so downstream handlers see them. Optional
        headers are only added when the corresponding value is truthy.

        Args:
            request: Request whose headers are extended in place.
            user_context: Verified user (or demo) context; must contain
                ``user_id`` and ``email``.
            tenant_id: Tenant the request is scoped to, if any. When set, the
                tenant service is also asked for the parent tenant id; a
                failure of that call is logged and ignored.
        """
        # Log what we're injecting for debugging
        logger.debug(
            "Injecting context headers",
            user_id=user_context.get("user_id"),
            user_type=user_context.get("type", ""),
            service_name=user_context.get("service", ""),
            role=user_context.get("role", ""),
            tenant_id=tenant_id,
            path=request.url.path
        )

        raw_headers = self._mutable_header_list(request)

        def add_header(name: bytes, value: Any) -> None:
            # str() first: ids may be non-string objects (e.g. UUID), for
            # which a bare .encode() would raise AttributeError.
            raw_headers.append((name, str(value).encode()))

        # Add user context headers
        add_header(b"x-user-id", user_context["user_id"])
        add_header(b"x-user-email", user_context["email"])
        add_header(b"x-user-role", user_context.get("role", "user"))

        user_type = user_context.get("type", "")
        if user_type:
            add_header(b"x-user-type", user_type)

        service_name = user_context.get("service", "")
        if service_name:
            add_header(b"x-service-name", service_name)

        # Add tenant context if available
        if tenant_id:
            add_header(b"x-tenant-id", tenant_id)

        # Add subscription tier if available
        subscription_tier = user_context.get("subscription_tier", "")
        if subscription_tier:
            add_header(b"x-subscription-tier", subscription_tier)

        # Add is_demo flag for demo sessions
        if user_context.get("is_demo", False):
            raw_headers.append((b"x-is-demo", b"true"))

        # Add demo session context headers for backend services
        demo_session_id = user_context.get("demo_session_id", "")
        if demo_session_id:
            add_header(b"x-demo-session-id", demo_session_id)

        demo_account_type = user_context.get("demo_account_type", "")
        if demo_account_type:
            add_header(b"x-demo-account-type", demo_account_type)

        # Add hierarchical access headers if tenant context exists
        if tenant_id:
            tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
            can_view_children = getattr(request.state, 'can_view_children', False)

            add_header(b"x-tenant-access-type", tenant_access_type)
            add_header(b"x-can-view-children", can_view_children)

            # If this is hierarchical access, include parent tenant ID
            # Get parent tenant ID from the tenant service if available
            try:
                async with httpx.AsyncClient(timeout=3.0) as client:
                    response = await client.get(
                        f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
                        headers={"Authorization": request.headers.get("Authorization", "")}
                    )

                if response.status_code == 200:
                    hierarchy_data = response.json()
                    parent_tenant_id = hierarchy_data.get("parent_tenant_id")
                    if parent_tenant_id:
                        add_header(b"x-parent-tenant-id", parent_tenant_id)
            except Exception as e:
                # Best effort: the parent tenant header is optional.
                logger.warning(f"Failed to get parent tenant ID: {e}")

        # Add gateway identification
        raw_headers.append((b"x-forwarded-by", b"bakery-gateway"))

    async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
        """
        Get tenant subscription tier using fast cached endpoint

        Args:
            tenant_id: Tenant ID
            request: FastAPI request; its Authorization header is forwarded

        Returns:
            The ``tier`` value reported by the tenant service ("starter" when
            the key is absent), or "starter" when the service does not answer
            200 or the call raises.
        """
        try:
            # Use fast cached subscription tier endpoint (has its own Redis caching)
            async with httpx.AsyncClient(timeout=3.0) as client:
                headers = {"Authorization": request.headers.get("Authorization", "")}
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/tier",
                    headers=headers
                )

            if response.status_code == 200:
                tier_data = response.json()
                subscription_tier = tier_data.get("tier", "starter")

                logger.debug("Subscription tier from cached endpoint",
                             tenant_id=tenant_id,
                             tier=subscription_tier,
                             cached=tier_data.get("cached", False))

                return subscription_tier

            logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
            return "starter"  # Default to starter

        except Exception as e:
            logger.error(f"Error getting tenant subscription tier: {e}")
            return "starter"  # Default to starter on error