""" API Gateway - Central entry point for all microservices Handles routing, authentication, rate limiting, and cross-cutting concerns """ import asyncio import structlog from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse import httpx import time import redis.asyncio as aioredis from typing import Dict, Any from app.core.config import settings from app.core.service_discovery import ServiceDiscovery from app.middleware.auth import AuthMiddleware from app.middleware.logging import LoggingMiddleware from app.middleware.rate_limit import RateLimitMiddleware from app.routes import auth, tenant, notification, nominatim, user from shared.monitoring.logging import setup_logging from shared.monitoring.metrics import MetricsCollector # Setup logging setup_logging("gateway", settings.LOG_LEVEL) logger = structlog.get_logger() # Create FastAPI app app = FastAPI( title="Bakery Forecasting API Gateway", description="Central API Gateway for bakery forecasting microservices", version="1.0.0", docs_url="/docs", redoc_url="/redoc" ) # Initialize metrics collector metrics_collector = MetricsCollector("gateway") # Service discovery service_discovery = ServiceDiscovery() # Redis client for SSE streaming redis_client = None # CORS middleware - Add first app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS_LIST, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Custom middleware - Add in correct order (outer to inner) app.add_middleware(LoggingMiddleware) app.add_middleware(RateLimitMiddleware, calls_per_minute=300) app.add_middleware(AuthMiddleware) # Include routers app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"]) app.include_router(user.router, prefix="/api/v1/users", tags=["users"]) app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"]) 
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])


@app.on_event("startup")
async def startup_event():
    """Application startup: connect to Redis and register/serve gateway metrics."""
    global redis_client
    logger.info("Starting API Gateway")

    # Connect to Redis for SSE streaming. Best-effort: if Redis is down the
    # gateway still serves HTTP routes; /api/events will answer 503 instead.
    try:
        redis_client = aioredis.from_url(settings.REDIS_URL)
        logger.info("Connected to Redis for SSE streaming")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")

    # Register gateway-level metrics before starting the exporter.
    metrics_collector.register_counter(
        "gateway_auth_requests_total", "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total", "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total", "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds", "Request duration in seconds"
    )
    logger.info("Metrics registered successfully")

    # Expose metrics on a dedicated port, separate from the API itself.
    metrics_collector.start_metrics_server(8080)
    logger.info("API Gateway started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Application shutdown: release external connections."""
    global redis_client
    logger.info("Shutting down API Gateway")

    # Close Redis connection.
    # NOTE: close() is deprecated in redis-py >= 5; migrate to aclose()
    # once the dependency floor allows.
    if redis_client:
        await redis_client.close()

    # TODO: re-enable once ServiceDiscovery provides an async cleanup hook.
    # await service_discovery.cleanup()

    logger.info("API Gateway shutdown complete")


@app.get("/health")
async def health_check():
    """Liveness probe: static health payload plus the current timestamp."""
    return {
        "status": "healthy",
        "service": "api-gateway",
        "version": "1.0.0",
        "timestamp": time.time()
    }


@app.get("/metrics")
async def metrics():
    """Metrics endpoint placeholder (actual metrics are served on port 8080)."""
    return {"metrics": "enabled"}


# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/events")
async def events_stream(request: Request, token: str):
    """Server-Sent Events stream for real-time notifications.

    Subscribes to the tenant-specific Redis channel ``alerts:{tenant_id}``
    and forwards every published message to the client as an SSE event.

    Args:
        request: Incoming request; used to detect client disconnects.
        token: JWT (query parameter) carrying ``tenant_id`` and ``user_id``.

    Raises:
        HTTPException: 503 when Redis is unavailable, 401 for bad tokens.
    """
    global redis_client
    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract tenant_id from the JWT.
    # SECURITY: the signature is deliberately NOT verified here — this relies
    # on AuthMiddleware having already authenticated the request. If this
    # endpoint is ever reachable without that middleware, verify the token.
    try:
        import jwt
        import json as json_lib

        payload = jwt.decode(token, options={"verify_signature": False})
        tenant_id = payload.get('tenant_id')
        user_id = payload.get('user_id')

        if not tenant_id:
            raise HTTPException(status_code=401, detail="Invalid token: missing tenant_id")
    except HTTPException:
        # Keep the specific 401 detail instead of masking it below.
        raise
    except Exception as e:
        logger.error(f"Token decode error: {e}")
        raise HTTPException(status_code=401, detail="Invalid token")

    logger.info(f"SSE connection established for tenant: {tenant_id}")

    async def event_generator():
        """Yield SSE frames sourced from the tenant's Redis pub/sub channel."""
        pubsub = None
        try:
            # Subscribe to tenant-specific alert channel
            pubsub = redis_client.pubsub()
            channel_name = f"alerts:{tenant_id}"
            await pubsub.subscribe(channel_name)

            # Send initial connection event so the client knows the stream is live
            yield f"event: connection\n"
            yield f"data: {json_lib.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"

            heartbeat_counter = 0
            while True:
                # Stop streaming once the client goes away.
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                try:
                    # Wait up to 10s for the next pub/sub message.
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True),
                        timeout=10.0
                    )

                    if message and message['type'] == 'message':
                        # Forward the alert/notification from Redis
                        alert_data = json_lib.loads(message['data'])

                        # Map alert metadata to an SSE event type: only
                        # high/urgent alerts escalate to "inventory_alert";
                        # everything else is a plain "notification".
                        event_type = "notification"
                        if (alert_data.get('item_type') == 'alert'
                                and alert_data.get('severity') in ['high', 'urgent']):
                            event_type = "inventory_alert"

                        yield f"event: {event_type}\n"
                        yield f"data: {json_lib.dumps(alert_data)}\n\n"
                        logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")

                except asyncio.TimeoutError:
                    # Send heartbeat every 10 timeouts (100 seconds) so
                    # intermediaries don't drop the idle connection.
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield f"event: heartbeat\n"
                        yield f"data: {json_lib.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}")
        finally:
            # Best-effort cleanup: the Redis connection may already be gone,
            # and an exception here would mask the original stream error.
            if pubsub:
                try:
                    await pubsub.unsubscribe()
                    await pubsub.close()
                except Exception as e:
                    logger.debug(f"SSE pubsub cleanup error for tenant {tenant_id}: {e}")
            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """WebSocket proxy that forwards connections directly to training service.

    Accepts the frontend connection, dials the training service's matching
    WebSocket endpoint, and relays messages in both directions until either
    side disconnects.

    Args:
        websocket: The frontend WebSocket connection.
        tenant_id: Tenant whose training job is being observed.
        job_id: Training job identifier.
    """
    await websocket.accept()

    # Get token from query params; reject unauthenticated connections early.
    token = websocket.query_params.get("token")
    if not token:
        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
        await websocket.close(code=1008, reason="Authentication token required")
        return

    logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")

    # Build WebSocket URL to training service (http(s):// -> ws(s)://).
    # SECURITY: the token is forwarded in the query string and may therefore
    # appear in the training service's access logs.
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    try:
        # Connect to training service WebSocket
        import websockets

        async with websockets.connect(training_ws_url) as training_ws:
            logger.info(f"Connected to training service WebSocket for job {job_id}")

            async def forward_to_training():
                """Forward messages from frontend to training service."""
                try:
                    async for message in websocket.iter_text():
                        await training_ws.send(message)
                except Exception as e:
                    logger.error(f"Error forwarding to training service: {e}")

            async def forward_to_frontend():
                """Forward messages from training service to frontend."""
                try:
                    async for message in training_ws:
                        await websocket.send_text(message)
                except Exception as e:
                    logger.error(f"Error forwarding to frontend: {e}")

            # Run both directions concurrently; return_exceptions keeps one
            # side's failure from propagating and cancelling the gather.
            await asyncio.gather(
                forward_to_training(),
                forward_to_frontend(),
                return_exceptions=True
            )

    except Exception as e:
        logger.error(f"WebSocket proxy error for job {job_id}: {e}")
        try:
            await websocket.close(code=1011, reason="Training service connection failed")
        except Exception:
            # The frontend socket may already be closed; nothing more to do.
            pass
    finally:
        logger.info(f"WebSocket proxy closed for job {job_id}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)