""" API Gateway - Central entry point for all microservices Handles routing, authentication, rate limiting, and cross-cutting concerns """ import asyncio import json import structlog from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse import httpx import time from shared.redis_utils import initialize_redis, close_redis, get_redis_client from typing import Dict, Any from app.core.config import settings from app.middleware.request_id import RequestIDMiddleware from app.middleware.auth import AuthMiddleware from app.middleware.logging import LoggingMiddleware from app.middleware.rate_limit import RateLimitMiddleware from app.middleware.subscription import SubscriptionMiddleware from app.middleware.demo_middleware import DemoMiddleware from app.middleware.read_only_mode import ReadOnlyModeMiddleware from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context from shared.monitoring.logging import setup_logging from shared.monitoring.metrics import MetricsCollector # Setup logging setup_logging("gateway", settings.LOG_LEVEL) logger = structlog.get_logger() # Create FastAPI app app = FastAPI( title="Bakery Forecasting API Gateway", description="Central API Gateway for bakery forecasting microservices", version="1.0.0", docs_url="/docs", redoc_url="/redoc", redirect_slashes=False # Disable automatic trailing slash redirects ) # Initialize metrics collector metrics_collector = MetricsCollector("gateway") # Redis client for SSE streaming redis_client = None # CORS middleware - Add first app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS_LIST, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Custom middleware - Add in REVERSE order (last added = first executed) # Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware 
# -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
#
# NOTE: Starlette runs the LAST-added middleware first on the request path (it
# is the outermost wrapper), so middleware is added in reverse of the desired
# request-phase order shown above.
app.add_middleware(LoggingMiddleware)  # Executes 7th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Executes 6th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 5th
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 4th - Enforce read-only mode
app.add_middleware(AuthMiddleware)  # Executes 3rd - Authenticates and injects request.state.user
app.add_middleware(DemoMiddleware)  # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware)  # Executes 1st (outermost) - Generates request ID for tracing

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"])  # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


@app.on_event("startup")
async def startup_event():
    """Initialize the shared Redis client, register metrics, start the metrics server."""
    global redis_client
    logger.info("Starting API Gateway")

    # Initialize shared Redis connection. Best-effort: a failure degrades SSE
    # streaming (503 on /api/events) but must not prevent the gateway from starting.
    try:
        await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
        redis_client = await get_redis_client()
        logger.info("Connected to Redis for SSE streaming")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")

    metrics_collector.register_counter(
        "gateway_auth_requests_total", "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total", "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total", "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds", "Request duration in seconds"
    )
    logger.info("Metrics registered successfully")
    metrics_collector.start_metrics_server(8080)
    logger.info("API Gateway started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Release shared resources on shutdown."""
    logger.info("Shutting down API Gateway")
    # Close shared Redis connection
    await close_redis()
    logger.info("API Gateway shutdown complete")


@app.get("/health")
async def health_check():
    """Health check endpoint used by orchestration probes."""
    return {
        "status": "healthy",
        "service": "api-gateway",
        "version": "1.0.0",
        "timestamp": time.time()
    }


@app.get("/metrics")
async def metrics():
    """Metrics endpoint for monitoring (actual metrics are served on port 8080)."""
    return {"metrics": "enabled"}


# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/events")
async def events_stream(request: Request, tenant_id: str):
    """
    Server-Sent Events stream for real-time notifications.

    Authentication is handled by the auth middleware via a query-param token;
    the resulting user context is read from ``request.state.user``. The tenant
    ID is provided by the frontend as a query parameter.
    """
    global redis_client
    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract user context set by the auth middleware. Guard against the
    # attribute being absent (misconfigured route) so the client gets a clean
    # 401 instead of an AttributeError -> 500.
    user_context = getattr(request.state, "user", None)
    if not user_context:
        raise HTTPException(status_code=401, detail="Authentication required")
    user_id = user_context.get('user_id')
    email = user_context.get('email')

    # Validate tenant_id parameter
    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

    logger.info(f"SSE connection request for user {email} ({user_id}), tenant {tenant_id}")

    async def event_generator():
        """Yield SSE frames sourced from the tenant's Redis pub/sub alert channel."""
        pubsub = None
        try:
            # Subscribe to tenant-specific alert channel
            pubsub = redis_client.pubsub()
            channel_name = f"alerts:{tenant_id}"
            await pubsub.subscribe(channel_name)

            # Initial handshake event so the client knows the stream is live
            yield f"event: connection\n"
            yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"

            # Replay currently-active alerts from the Redis cache so a freshly
            # connected client starts with the full picture.
            try:
                cache_key = f"active_alerts:{tenant_id}"
                cached_alerts = await redis_client.get(cache_key)
                if cached_alerts:
                    active_items = json.loads(cached_alerts)
                    logger.info(f"Sending initial_items to tenant {tenant_id}, count: {len(active_items)}")
                    yield f"event: initial_items\n"
                    yield f"data: {json.dumps(active_items)}\n\n"
                else:
                    logger.info(f"No cached alerts found for tenant {tenant_id}")
                    yield f"event: initial_items\n"
                    yield f"data: {json.dumps([])}\n\n"
            except Exception as e:
                logger.error(f"Error fetching initial items for tenant {tenant_id}: {e}")
                # Still send an empty initial_items event so the client can render
                yield f"event: initial_items\n"
                yield f"data: {json.dumps([])}\n\n"

            heartbeat_counter = 0
            while True:
                # Check if client has disconnected
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                try:
                    # Poll Redis for the next published message (10s timeout)
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0
                    )
                    if message and message['type'] == 'message':
                        # Forward the alert/notification from Redis
                        alert_data = json.loads(message['data'])

                        # High/urgent alerts get a dedicated event type; everything
                        # else (low-severity alerts, recommendations) is a plain
                        # notification.
                        if (alert_data.get('item_type') == 'alert'
                                and alert_data.get('severity') in ['high', 'urgent']):
                            event_type = "inventory_alert"
                        else:
                            event_type = "notification"

                        yield f"event: {event_type}\n"
                        yield f"data: {json.dumps(alert_data)}\n\n"
                        logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
                except asyncio.TimeoutError:
                    # Send a heartbeat every 10 poll timeouts (~100 seconds) to keep
                    # intermediaries from closing the idle connection.
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield f"event: heartbeat\n"
                        yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}")
        finally:
            if pubsub:
                await pubsub.unsubscribe()
                await pubsub.close()
            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """
    WebSocket proxy to the training service with token verification only.

    Verifies the JWT supplied as a ``token`` query parameter, then forwards
    frames in both directions between the frontend and the training service
    until either side disconnects.
    """
    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        logger.warning("WebSocket proxy rejected - missing token", job_id=job_id, tenant_id=tenant_id)
        # Must accept before close so the client receives the 1008 policy code
        await websocket.accept()
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Verify token
    from shared.auth.jwt_handler import JWTHandler
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
    try:
        payload = jwt_handler.verify_token(token)
        if not payload or not payload.get('user_id'):
            logger.warning("WebSocket proxy rejected - invalid token", job_id=job_id, tenant_id=tenant_id)
            await websocket.accept()
            await websocket.close(code=1008, reason="Invalid token")
            return
        logger.info("WebSocket proxy - token verified",
                    user_id=payload.get('user_id'), tenant_id=tenant_id, job_id=job_id)
    except Exception as e:
        logger.warning("WebSocket proxy - token verification failed", job_id=job_id, error=str(e))
        await websocket.accept()
        await websocket.close(code=1008, reason="Token verification failed")
        return

    # Accept the connection
    await websocket.accept()

    # Build WebSocket URL to training service (http(s):// -> ws(s)://)
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    logger.info("Gateway proxying WebSocket to training service", job_id=job_id,
                training_ws_url=training_ws_url.replace(token, '***'))  # never log the raw token

    training_ws = None
    try:
        # Connect to training service WebSocket
        import websockets
        from websockets.protocol import State

        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=120,  # Send ping every 2 minutes (tolerates long training operations)
            ping_timeout=60,    # Wait up to 1 minute for pong (graceful timeout)
            close_timeout=60,   # Increase close timeout for graceful shutdown
            open_timeout=30
        )
        logger.info("Gateway connected to training service WebSocket", job_id=job_id)

        async def forward_frontend_to_training():
            """Forward messages from frontend to training service."""
            try:
                while training_ws and training_ws.state == State.OPEN:
                    data = await websocket.receive()
                    if data.get("type") == "websocket.receive":
                        if "text" in data:
                            await training_ws.send(data["text"])
                        elif "bytes" in data:
                            await training_ws.send(data["bytes"])
                    elif data.get("type") == "websocket.disconnect":
                        break
            except Exception as e:
                logger.debug("Frontend to training forward ended", error=str(e))

        async def forward_training_to_frontend():
            """Forward messages from training service to frontend."""
            message_count = 0
            try:
                while training_ws and training_ws.state == State.OPEN:
                    message = await training_ws.recv()
                    await websocket.send_text(message)
                    message_count += 1
                    # Log every 10th message to track connectivity
                    if message_count % 10 == 0:
                        logger.debug("WebSocket proxy active", job_id=job_id,
                                     messages_forwarded=message_count)
            except Exception as e:
                logger.info("Training to frontend forward ended", job_id=job_id,
                            messages_forwarded=message_count, error=str(e))

        # Run both forwarding tasks concurrently; gather returns when both end
        await asyncio.gather(
            forward_frontend_to_training(),
            forward_training_to_frontend(),
            return_exceptions=True
        )

    except Exception as e:
        logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        # Best-effort cleanup of both legs of the proxy
        if training_ws and training_ws.state == State.OPEN:
            try:
                await training_ws.close()
            except Exception:
                pass
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception:
            pass
        logger.info("WebSocket proxy connection closed", job_id=job_id)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)