"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""

import asyncio
import json
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
import redis.asyncio as aioredis
from typing import Dict, Any

from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.routes import auth, tenant, notification, nominatim, user, subscription, demo
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector

# Setup logging (must run before the module-level logger is obtained)
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()

# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Service discovery
service_discovery = ServiceDiscovery()

# Redis client for SSE streaming.
# Stays None until the startup handler assigns it; the SSE endpoint answers
# 503 while it is unset.
redis_client = None

# CORS middleware - Add first.
# Because it is registered before every other middleware it ends up as the
# innermost layer of the stack (see the ordering note below).
# NOTE(review): a response produced by an outer middleware without calling the
# next layer (e.g. an auth rejection) would not pass through CORSMiddleware and
# may therefore lack CORS headers - confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Custom middleware - Add in REVERSE order (last added = first executed)
# Execution order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)  # Executes 5th (innermost of the custom middleware, added first)
# Remaining custom middleware. Each add_middleware() call wraps the stack
# built so far, so the LAST middleware added is the OUTERMOST layer and is the
# first one to see an incoming request.
# Request order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware
#                -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Executes 4th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 3rd
app.add_middleware(AuthMiddleware)  # Executes 2nd - Checks for demo context
app.add_middleware(DemoMiddleware)  # Executes 1st (outermost) - Sets demo user context FIRST

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(user.router, prefix="/api/v1/users", tags=["users"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


@app.on_event("startup")
async def startup_event():
    """Application startup: create the Redis client and register gateway metrics."""
    global redis_client

    logger.info("Starting API Gateway")

    # Connect to Redis for SSE streaming
    # NOTE(review): from_url() appears to build the client lazily (no network
    # round-trip here), so this log line does not prove Redis is reachable -
    # confirm against the redis-py version in use.
    try:
        redis_client = aioredis.from_url(settings.REDIS_URL)
        logger.info("Connected to Redis for SSE streaming")
    except Exception as e:
        # Best effort: the gateway still starts; /api/events answers 503
        # while redis_client is None.
        logger.error(f"Failed to connect to Redis: {e}")

    metrics_collector.register_counter(
        "gateway_auth_requests_total",
        "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total",
        "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total",
        "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds",
        "Request duration in seconds"
    )

    logger.info("Metrics registered successfully")

    metrics_collector.start_metrics_server(8080)

    logger.info("API Gateway started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Application shutdown: close the Redis client if one was created."""
    global redis_client

    logger.info("Shutting down API Gateway")

    # Close Redis connection
    if redis_client:
        await redis_client.close()

    # Clean up service discovery
    # await service_discovery.cleanup()

    logger.info("API Gateway shutdown complete")


@app.get("/health")
async def health_check():
    """Health check endpoint; returns a static status payload with the current time."""
    return {
        "status": "healthy",
        "service": "api-gateway",
        "version": "1.0.0",
        "timestamp": time.time()
    }


@app.get("/metrics")
async def metrics():
    """Metrics endpoint for monitoring (static placeholder payload)."""
    return {"metrics": "enabled"}


# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

# How long a single Redis pub/sub poll may block before the loop re-checks
# whether the client is still connected.
_SSE_POLL_TIMEOUT_SECONDS = 10.0
# A heartbeat event is emitted after this many consecutive-or-not idle polls
# (10 polls x 10 s = roughly every 100 seconds of idle time).
_SSE_HEARTBEAT_EVERY_N_POLLS = 10


def _format_sse(event_type, payload):
    """Render one SSE frame (``event:`` line, ``data:`` line, blank line)."""
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"


def _sse_event_type(alert_data):
    """Map a decoded alert payload to the SSE event name sent to the frontend."""
    is_urgent_alert = (
        alert_data.get('item_type') == 'alert'
        and alert_data.get('severity') in ('high', 'urgent')
    )
    return "inventory_alert" if is_urgent_alert else "notification"


@app.get("/api/events")
async def events_stream(request: Request, tenant_id: str):
    """
    Server-Sent Events stream for real-time notifications.

    Authentication is handled by auth middleware via query param token.
    User context is available in request.state.user (injected by middleware).
    Tenant ID is provided by the frontend as a query parameter.

    Messages published on the Redis channel ``alerts:{tenant_id}`` are
    forwarded to the client; a heartbeat event is sent during idle periods.

    Raises:
        HTTPException 503: the Redis client has not been created.
        HTTPException 401: no user context was attached to the request.
        HTTPException 400: ``tenant_id`` is empty.
    """
    if redis_client is None:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract user context from request state (set by auth middleware).
    # Accessing a missing attribute on request.state raises AttributeError,
    # which would surface as an opaque 500 - answer 401 instead.
    user_context = getattr(request.state, "user", None)
    if user_context is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    email = user_context.get('email')

    # Validate tenant_id parameter
    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

    # NOTE(review): tenant_id comes straight from the query string and is not
    # checked against the authenticated user here - confirm that a middleware
    # enforces tenant membership for this route.
    logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
    logger.info(f"SSE connection established for tenant: {tenant_id}")

    async def event_generator():
        """Generate server-sent events from Redis pub/sub"""
        pubsub = None
        try:
            # Subscribe to tenant-specific alert channel
            pubsub = redis_client.pubsub()
            channel_name = f"alerts:{tenant_id}"
            await pubsub.subscribe(channel_name)

            # Send initial connection event
            yield _format_sse(
                "connection",
                {'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()},
            )

            idle_polls = 0

            while True:
                # Check if client has disconnected
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                # The timeout must be given to get_message() itself: without
                # it the call returns None immediately, so an outer
                # asyncio.wait_for() never fires, the loop spins, and no
                # heartbeat is ever sent.
                try:
                    message = await pubsub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=_SSE_POLL_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    message = None

                if message is None:
                    # Idle poll: count it and emit a heartbeat when due
                    idle_polls += 1
                    if idle_polls >= _SSE_HEARTBEAT_EVERY_N_POLLS:
                        yield _format_sse("heartbeat", {'type': 'heartbeat', 'timestamp': time.time()})
                        idle_polls = 0
                    continue

                if message.get('type') != 'message':
                    continue

                # Forward the alert/notification from Redis. A single
                # malformed payload is skipped rather than ending the stream.
                try:
                    alert_data = json.loads(message['data'])
                except (ValueError, TypeError) as e:
                    logger.warning(f"SSE skipping undecodable message for tenant {tenant_id}: {e}")
                    continue
                if not isinstance(alert_data, dict):
                    logger.warning(f"SSE skipping non-object message for tenant {tenant_id}")
                    continue

                yield _format_sse(_sse_event_type(alert_data), alert_data)

                logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
            # Cancellation must propagate to the caller after cleanup
            raise
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}")
        finally:
            if pubsub is not None:
                # Run both cleanup steps even if the first one fails
                for cleanup in (pubsub.unsubscribe, pubsub.close):
                    try:
                        await cleanup()
                    except Exception as e:
                        logger.warning(f"SSE pubsub cleanup error for tenant {tenant_id}: {e}")
            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """
    WebSocket proxy that forwards connections directly to training service
    with enhanced token validation.

    The frontend connection is accepted, the ``token`` query parameter is
    verified, and then messages are relayed in both directions between the
    frontend and the training service until either side goes away. Both
    sockets are closed on exit.
    """
    await websocket.accept()

    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Validate token using auth middleware
    from app.middleware.auth import jwt_handler
    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid authentication token")
            return

        # Check token expiration (a token without 'exp' is treated as expired)
        if payload.get('exp', 0) < time.time():
            logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
            await websocket.close(code=1008, reason="Token expired")
            return

        logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")

    except Exception as e:
        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
        await websocket.close(code=1008, reason="Token validation failed")
        return

    # NOTE(review): the token is validated but tenant_id from the path is not
    # compared with any tenant claim here - confirm the training service
    # enforces tenant access.
    logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")

    # Build WebSocket URL to training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    training_ws = None
    heartbeat_task = None

    try:
        # Connect to training service WebSocket with proper timeout configuration
        import websockets

        # Configure timeouts to coordinate with frontend (30s heartbeat) and training service
        # DISABLE gateway-level ping to avoid dual-ping conflicts - let frontend handle ping/pong
        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=None,  # DISABLED: Let frontend handle ping/pong via message forwarding
            ping_timeout=None,   # DISABLED: No independent ping mechanism
            close_timeout=15,    # Reasonable close timeout
            max_size=2**20,      # 1MB max message size
            max_queue=32         # Max queued messages
        )

        logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)")

        # Track connection state properly due to FastAPI WebSocket state propagation bug
        loop = asyncio.get_running_loop()
        connection_alive = True
        last_activity = loop.time()

        # NOTE(review): `.open` / `.closed` below are attributes of the legacy
        # websockets client protocol - confirm they exist in the installed
        # websockets version.

        async def check_connection_health():
            """Monitor connection health based on activity timestamps only - no WebSocket interference"""
            nonlocal connection_alive, last_activity

            while connection_alive:
                try:
                    await asyncio.sleep(30)  # Check every 30 seconds (aligned with frontend heartbeat)

                    current_time = loop.time()

                    # Check if we haven't received any activity for too long
                    # Frontend sends ping every 30s, so 90s = 3 missed pings before considering dead
                    if current_time - last_activity > 90:
                        logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
                        # Don't forcibly close - let the forwarding loops handle actual connection issues
                        # This is just monitoring/logging now
                    else:
                        logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")

                except Exception as e:
                    logger.error(f"Connection health monitoring error for job {job_id}: {e}")
                    break

        async def forward_to_training():
            """Forward messages from frontend to training service with proper error handling"""
            nonlocal connection_alive, last_activity

            try:
                while connection_alive and training_ws and training_ws.open:
                    try:
                        # Use longer timeout to avoid conflicts with frontend 30s heartbeat
                        # Frontend sends ping every 30s, so we need to allow for some latency
                        message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0)
                        last_activity = loop.time()

                        # Forward the message to training service
                        await training_ws.send(message)
                        logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")

                    except asyncio.TimeoutError:
                        # No message received in 45 seconds, continue loop
                        # This allows for frontend 30s heartbeat + network latency + processing time
                        continue
                    except WebSocketDisconnect:
                        # A frontend disconnect is a normal end of session, not an error
                        logger.info(f"Frontend WebSocket disconnected for job {job_id}")
                        connection_alive = False
                        break
                    except Exception as e:
                        logger.error(f"Error receiving from frontend for job {job_id}: {e}")
                        connection_alive = False
                        break

            except Exception as e:
                logger.error(f"Error in forward_to_training for job {job_id}: {e}")
                connection_alive = False

        async def forward_to_frontend():
            """Forward messages from training service to frontend with proper error handling"""
            nonlocal connection_alive, last_activity

            try:
                while connection_alive and training_ws and training_ws.open:
                    try:
                        # Use coordinated timeout - training service expects messages every 60s
                        # This should be longer than training service timeout to avoid premature closure
                        message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
                        last_activity = loop.time()

                        # Forward the message to frontend
                        await websocket.send_text(message)
                        logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")

                    except asyncio.TimeoutError:
                        # No message received in 75 seconds, continue loop
                        # Training service sends heartbeats, so this indicates potential issues
                        continue
                    except websockets.exceptions.ConnectionClosed as e:
                        # The upstream closing its socket ends the session; not an error
                        logger.info(f"Training service WebSocket closed for job {job_id}: {e}")
                        connection_alive = False
                        break
                    except Exception as e:
                        logger.error(f"Error receiving from training service for job {job_id}: {e}")
                        connection_alive = False
                        break

            except Exception as e:
                logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
                connection_alive = False

        # Start connection health monitoring
        heartbeat_task = asyncio.create_task(check_connection_health())

        # Run both forwarding directions concurrently and stop as soon as
        # EITHER finishes. Waiting for both (gather) would keep the proxy and
        # the upstream socket open until the surviving direction hit its own
        # 45s/75s receive timeout after the peer was already gone.
        forward_tasks = [
            asyncio.create_task(forward_to_training()),
            asyncio.create_task(forward_to_frontend()),
        ]
        try:
            await asyncio.wait(forward_tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            connection_alive = False
            for task in forward_tasks:
                task.cancel()  # no-op for a task that has already finished
            # Reap both tasks; cancellation results are collected, not raised
            await asyncio.gather(*forward_tasks, return_exceptions=True)

    except websockets.exceptions.ConnectionClosedError as e:
        logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
    except websockets.exceptions.WebSocketException as e:
        logger.error(f"WebSocket exception for job {job_id}: {e}")
    except Exception as e:
        logger.error(f"WebSocket proxy error for job {job_id}: {e}")
    finally:
        # Cleanup
        if heartbeat_task and not heartbeat_task.done():
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        if training_ws and not training_ws.closed:
            try:
                await training_ws.close()
            except Exception as e:
                logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")

        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy connection closed")
        except Exception as e:
            logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")

        logger.info(f"WebSocket proxy cleanup completed for job {job_id}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)