"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import httpx
import time
from typing import Dict, Any

from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.routes import auth, tenant, notification, nominatim, user
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector

# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()

# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Service discovery
service_discovery = ServiceDiscovery()

# ---------------------------------------------------------------------------
# Middleware
#
# add_middleware() wraps the existing stack, so the LAST middleware added is
# the OUTERMOST one (first to see a request, last to see the response).
# Resulting request path:
#     CORS -> Auth -> RateLimit -> Logging -> routes
# ---------------------------------------------------------------------------

# Custom middleware, registered innermost-first. This relative order is kept
# as-is so Auth still runs before RateLimit and Logging.
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(AuthMiddleware)

# CORS is registered last so it is outermost. It can then answer preflight
# (OPTIONS) requests itself and add CORS headers to responses produced by the
# inner middleware, such as auth or rate-limit rejections. Without those
# headers, browsers would hide the real error behind a CORS failure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(user.router, prefix="/api/v1/users", tags=["users"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])


@app.on_event("startup")
async def startup_event():
    """Register the gateway's metrics and start the metrics server on port 8080."""
    logger.info("Starting API Gateway")

    metrics_collector.register_counter(
        "gateway_auth_requests_total",
        "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total",
        "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total",
        "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds",
        "Request duration in seconds"
    )
    logger.info("Metrics registered successfully")

    # Separate port from the API itself (8000 when run via __main__ below).
    metrics_collector.start_metrics_server(8080)
    logger.info("API Gateway started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Log application shutdown. No resources are released here yet."""
    logger.info("Shutting down API Gateway")
    # TODO: release service-discovery resources here if ServiceDiscovery needs
    # cleanup (a call to `service_discovery.cleanup()` was previously stubbed out).
    logger.info("API Gateway shutdown complete")


@app.get("/health")
async def health_check():
    """Health check endpoint: static liveness payload plus the current Unix time."""
    return {
        "status": "healthy",
        "service": "api-gateway",
        "version": "1.0.0",
        "timestamp": time.time()
    }


@app.get("/metrics")
async def metrics():
    """Metrics endpoint for monitoring.

    Placeholder: it only reports that metrics are enabled. Presumably the real
    metrics are served by the metrics server started in ``startup_event`` on
    port 8080 — confirm before pointing scrapers here.
    """
    return {"metrics": "enabled"}


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

# Polling cadence for the training-progress proxy, in seconds.
_TRAINING_POLL_INTERVAL_SECONDS = 2
_TRAINING_ERROR_BACKOFF_SECONDS = 5
_TRAINING_STATUS_REQUEST_TIMEOUT_SECONDS = 5.0

# Job statuses that end the polling loop. This is a tuple rather than a set so
# membership tests never raise on unhashable values from upstream JSON.
_TERMINAL_TRAINING_STATUSES = ("completed", "failed", "cancelled")

# Upstream HTTP statuses that retrying cannot fix, mapped to the WebSocket
# close (code, reason) sent to the client. RFC 6455 defines 1008 as the
# generic policy-violation code.
_FATAL_TRAINING_STATUS_RESPONSES = {
    401: (1008, "Authentication failed"),
    403: (1008, "Authentication failed"),
    404: (1008, "Training job not found"),
}


async def _client_disconnected_during(websocket: WebSocket, delay: float) -> bool:
    """Wait ``delay`` seconds; return True early if the client disconnects.

    The proxy never otherwise reads from the socket. Without this, a client
    that leaves while the job status is unchanged would go unnoticed, and
    polling would continue indefinitely. Client messages are read and
    discarded; they do not shorten the wait.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + delay
    while True:
        remaining = deadline - loop.time()
        if remaining <= 0:
            return False
        try:
            message = await asyncio.wait_for(websocket.receive(), timeout=remaining)
        except asyncio.TimeoutError:
            return False
        if message.get("type") == "websocket.disconnect":
            return True


async def _close_websocket_quietly(websocket: WebSocket, code: int, reason: str) -> None:
    """Best-effort close; the socket may already have been closed by either side."""
    try:
        await websocket.close(code=code, reason=reason)
    except Exception as e:  # nothing further can be done for this connection
        logger.debug(f"WebSocket close skipped: {e}")


async def _relay_training_status(websocket: WebSocket, client: httpx.AsyncClient,
                                 status_url: str, headers: Dict[str, str]):
    """Poll the training status endpoint and push each changed payload to the client.

    Returns the ``(close_code, reason)`` the server should close with, or
    ``None`` if the client disconnected. WebSocket send failures other than
    ``WebSocketDisconnect`` propagate to the caller instead of being retried.
    """
    last_status = None
    try:
        while True:
            delay = _TRAINING_POLL_INTERVAL_SECONDS
            try:
                response = await client.get(
                    status_url,
                    headers=headers,
                    timeout=_TRAINING_STATUS_REQUEST_TIMEOUT_SECONDS
                )
                fatal = _FATAL_TRAINING_STATUS_RESPONSES.get(response.status_code)
                if fatal is not None:
                    logger.warning(
                        f"Training status request rejected with HTTP {response.status_code}"
                    )
                    return fatal

                if response.status_code == 200:
                    current_status = response.json()

                    # Only send update if status changed
                    if current_status != last_status:
                        await websocket.send_json({
                            "type": "training_progress",
                            "data": current_status
                        })
                        last_status = current_status

                        status = (current_status.get("status")
                                  if isinstance(current_status, dict) else None)
                        if status in _TERMINAL_TRAINING_STATUSES:
                            await websocket.send_json({
                                "type": "training_" + status,
                                "data": current_status
                            })
                            return 1000, "Training finished"
            except httpx.TimeoutException:
                # Continue polling even if request times out, but back off.
                delay = _TRAINING_ERROR_BACKOFF_SECONDS
            except (httpx.HTTPError, ValueError) as e:
                # Transport errors and malformed JSON are treated as transient.
                logger.error(f"Error polling training status: {e}")
                delay = _TRAINING_ERROR_BACKOFF_SECONDS

            if await _client_disconnected_during(websocket, delay):
                logger.info("WebSocket client disconnected")
                return None
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
        return None


@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """WebSocket proxy for training progress updates.

    This is not a socket-to-socket proxy. The gateway polls the training
    service's HTTP status endpoint and pushes each changed status to the client.

    Protocol:
      * ``?token=...`` is required and is forwarded upstream as a Bearer
        token. A missing token closes the socket with 1008.
      * Messages sent, in order:
          - ``connection_established`` once;
          - ``training_progress`` whenever the status payload changes;
          - ``training_<status>`` once the job is completed, failed or
            cancelled, followed by a normal close (1000).
      * Upstream 401/403/404 closes the socket with 1008.
      * Unexpected errors close it with 1011.
    """
    await websocket.accept()

    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=1008, reason="Authentication token required")
        return

    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    status_url = f"{training_service_base}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/status"
    auth_headers = {"Authorization": f"Bearer {token}"}

    try:
        async with httpx.AsyncClient() as client:
            logger.info(f"Starting WebSocket proxy for training job {job_id}")

            # Send initial connection confirmation
            await websocket.send_json({
                "type": "connection_established",
                "job_id": job_id,
                "tenant_id": tenant_id
            })

            outcome = await _relay_training_status(websocket, client, status_url, auth_headers)
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected during setup")
        return
    except Exception as e:
        logger.error(f"WebSocket proxy error: {e}")
        await _close_websocket_quietly(websocket, 1011, "Internal server error")
        return

    if outcome is not None:
        close_code, close_reason = outcome
        await _close_websocket_quietly(websocket, close_code, close_reason)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)