""" API Gateway - Central entry point for all microservices Handles routing, authentication, rate limiting, and cross-cutting concerns """ import asyncio import json import structlog import resource import os import time from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect from fastapi.responses import StreamingResponse import httpx from shared.redis_utils import initialize_redis, close_redis, get_redis_client from shared.service_base import StandardFastAPIService from app.core.config import settings from app.core.header_manager import header_manager from app.middleware.request_id import RequestIDMiddleware from app.middleware.auth import AuthMiddleware from app.middleware.logging import LoggingMiddleware from app.middleware.rate_limit import RateLimitMiddleware from app.middleware.rate_limiting import APIRateLimitMiddleware from app.middleware.subscription import SubscriptionMiddleware from app.middleware.demo_middleware import DemoMiddleware from app.middleware.read_only_mode import ReadOnlyModeMiddleware from app.routes import auth, tenant, registration, nominatim, subscription, demo, pos, geocoding, poi_context, webhooks # Initialize logger logger = structlog.get_logger() # Check file descriptor limits try: soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) if soft_limit < 1024: logger.warning(f"Low file descriptor limit detected: {soft_limit}") else: logger.info(f"File descriptor limit: {soft_limit} (sufficient)") except Exception as e: logger.debug(f"Could not check file descriptor limits: {e}") # Global Redis client for SSE streaming redis_client = None class GatewayService(StandardFastAPIService): """Gateway Service with standardized monitoring setup""" def __init__(self, **kwargs): super().__init__(**kwargs) # Initialize HeaderManager early header_manager.initialize() logger.info("HeaderManager initialized") # Initialize Redis during service creation so it's available when needed try: # We need to run the async initialization in a sync context import asyncio try: # Check if there's already a running event loop loop = asyncio.get_running_loop() # If there is, we'll initialize Redis later in on_startup self.redis_initialized = False self.redis_client = None except RuntimeError: # No event loop running, safe to run the async function import asyncio import nest_asyncio nest_asyncio.apply() # Allow nested event loops async def init_redis(): await initialize_redis(settings.REDIS_URL, db=0, max_connections=50) return await get_redis_client() self.redis_client = asyncio.run(init_redis()) self.redis_initialized = True logger.info("Connected to Redis for SSE streaming") except Exception as e: logger.error(f"Failed to initialize Redis during service creation: {e}") self.redis_initialized = False self.redis_client = None async def on_startup(self, app): """Custom startup logic for Gateway""" global redis_client # Initialize Redis if not already done during service creation if not self.redis_initialized: try: await initialize_redis(settings.REDIS_URL, db=0, max_connections=50) self.redis_client = await get_redis_client() redis_client = self.redis_client # Update global variable self.redis_initialized = True logger.info("Connected to Redis for SSE streaming") except Exception as e: logger.error(f"Failed to connect to Redis during startup: {e}") # Register custom metrics for gateway-specific operations if self.telemetry_providers and self.telemetry_providers.app_metrics: logger.info("Gateway-specific metrics tracking enabled") await super().on_startup(app) 

    async def on_shutdown(self, app):
        """Custom shutdown logic for Gateway"""
        await super().on_shutdown(app)

        # Close Redis
        await close_redis()
        logger.info("Redis connection closed")


# Create service instance
service = GatewayService(
    service_name="gateway",
    app_name="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
    cors_origins=settings.CORS_ORIGINS_LIST,
    enable_metrics=True,
    enable_health_checks=True,
    enable_tracing=True,
    enable_cors=True
)

# Create FastAPI app
app = service.create_app()

# Add API rate limiting middleware with the Redis client. This must happen after app
# creation but before the other middleware that might depend on it.
if not getattr(service, 'redis_client', None):
    # Wait briefly in case Redis initialization is still completing
    # (blocks once at import time, before the server accepts requests).
    time.sleep(1)

if getattr(service, 'redis_client', None):
    app.add_middleware(APIRateLimitMiddleware, redis_client=service.redis_client)
    logger.info("API rate limiting middleware enabled")
else:
    logger.warning("Redis client not available for API rate limiting middleware")

# Add gateway-specific middleware (in REVERSE order of execution).
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware ->
# ReadOnlyModeMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware ->
# LoggingMiddleware -> APIRateLimitMiddleware (added above, so it runs innermost).
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(AuthMiddleware)
app.add_middleware(DemoMiddleware)
app.add_middleware(RequestIDMiddleware)

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(registration.router, prefix="/api/v1", tags=["registration"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
# Notification routes are handled by the tenant router at /api/v1/tenants/{tenant_id}/notifications/*
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
# Include webhooks at the root level to handle /api/v1/webhooks/*.
# Webhook routes are defined with full /api/v1/webhooks/* paths for consistency.
app.include_router(webhooks.router, prefix="", tags=["webhooks"])

# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================

def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
    """Determine which Redis channels to subscribe to based on filters"""
    all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
    all_classes = ["alerts", "notifications"]
    channels = []

    if not channel_filters:
        # Subscribe to ALL channels (backward compatible)
        for domain in all_domains:
            for event_class in all_classes:
                channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        channels.append(f"tenant:{tenant_id}:recommendations")
        channels.append(f"alerts:{tenant_id}")  # Legacy
        return channels

    # Parse filters and expand wildcards
    for filter_pattern in channel_filters:
        if filter_pattern == "*.*":
            for domain in all_domains:
                for event_class in all_classes:
                    channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
            channels.append(f"tenant:{tenant_id}:recommendations")
        elif filter_pattern.endswith(".*"):
            domain = filter_pattern.split(".")[0]
            for event_class in all_classes:
                channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        elif filter_pattern.startswith("*."):
            event_class = filter_pattern.split(".")[1]
            if event_class == "recommendations":
                channels.append(f"tenant:{tenant_id}:recommendations")
            else:
                for domain in all_domains:
                    channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        elif filter_pattern == "recommendations":
            channels.append(f"tenant:{tenant_id}:recommendations")
        else:
            channels.append(f"tenant:{tenant_id}:{filter_pattern}")

    return list(set(channels))
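
# Example expansions (illustrative only; for filtered calls the order of the returned
# list is not guaranteed because duplicates are removed via set()):
#   _get_subscription_channels("t1", [])                  -> every <domain>.<class> channel,
#                                                             "tenant:t1:recommendations",
#                                                             plus the legacy "alerts:t1"
#   _get_subscription_channels("t1", ["inventory.*"])      -> ["tenant:t1:inventory.alerts",
#                                                              "tenant:t1:inventory.notifications"]
#   _get_subscription_channels("t1", ["*.alerts"])         -> one "tenant:t1:<domain>.alerts"
#                                                             channel per known domain
#   _get_subscription_channels("t1", ["recommendations"])  -> ["tenant:t1:recommendations"]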


async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
    """Load initial state from Redis cache based on channel filters"""
    initial_events = []
    try:
        if not channel_filters:
            # Legacy cache
            legacy_cache_key = f"active_alerts:{tenant_id}"
            cached_data = await redis_client.get(legacy_cache_key)
            if cached_data:
                return json.loads(cached_data)

            # New domain-specific caches. Cache keys are assumed to be of the form
            # "active_events:<tenant>:<domain>.<class>", e.g. "...:inventory.alerts".
            all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
            all_classes = ["alerts", "notifications"]
            for domain in all_domains:
                for event_class in all_classes:
                    cache_key = f"active_events:{tenant_id}:{domain}.{event_class}"
                    cached_data = await redis_client.get(cache_key)
                    if cached_data:
                        initial_events.extend(json.loads(cached_data))

            # Recommendations
            recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
            cached_data = await redis_client.get(recommendations_cache_key)
            if cached_data:
                initial_events.extend(json.loads(cached_data))

            return initial_events

        # Load based on specific filters
        for filter_pattern in channel_filters:
            if "." in filter_pattern:
                parts = filter_pattern.split(".")
                domain = parts[0] if parts[0] != "*" else None
                event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None

                if domain and event_class:
                    cache_key = f"active_events:{tenant_id}:{domain}.{event_class}"
                    cached_data = await redis_client.get(cache_key)
                    if cached_data:
                        initial_events.extend(json.loads(cached_data))
                elif domain and not event_class:
                    for ec in ["alerts", "notifications"]:
                        cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
                        cached_data = await redis_client.get(cache_key)
                        if cached_data:
                            initial_events.extend(json.loads(cached_data))
                elif not domain and event_class:
                    all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
                    for d in all_domains:
                        cache_key = f"active_events:{tenant_id}:{d}.{event_class}"
                        cached_data = await redis_client.get(cache_key)
                        if cached_data:
                            initial_events.extend(json.loads(cached_data))
            elif filter_pattern == "recommendations":
                cache_key = f"active_events:{tenant_id}:recommendations"
                cached_data = await redis_client.get(cache_key)
                if cached_data:
                    initial_events.extend(json.loads(cached_data))

        return initial_events
    except Exception as e:
        logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
        return []


def _determine_event_type(event_data: dict) -> str:
    """Determine SSE event type from event data"""
    if 'event_class' in event_data:
        return event_data['event_class']
    if 'item_type' in event_data:
        if event_data['item_type'] == 'recommendation':
            return 'recommendation'
        return 'alert'
    return 'alert'
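
# How published payloads map to SSE event types (illustrative values, derived from the
# rules above; "stockout" stands in for any non-recommendation item_type):
#   {"event_class": "notifications", ...}  -> event: notifications
#   {"item_type": "recommendation", ...}   -> event: recommendation
#   {"item_type": "stockout", ...}         -> event: alert
#   {}                                     -> event: alert  (default)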


# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/v1/events")
async def events_stream(
    request: Request,
    tenant_id: str,
    channels: str = None
):
    """
    Server-Sent Events stream for real-time notifications with multi-channel support.

    Query Parameters:
        tenant_id: Tenant identifier (required)
        channels: Comma-separated channel filters (optional)
    """
    global redis_client

    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract user context from request state (populated by AuthMiddleware)
    user_context = getattr(request.state, 'user', None) or {}
    user_id = user_context.get('user_id')
    email = user_context.get('email')

    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

    # Parse channel filters
    channel_filters = []
    if channels:
        channel_filters = [c.strip() for c in channels.split(',') if c.strip()]

    logger.info(
        f"SSE connection request for user {email}, tenant {tenant_id}, "
        f"channels: {channel_filters or 'all'}"
    )

    async def event_generator():
        """Generate server-sent events from Redis pub/sub"""
        pubsub = None
        try:
            pubsub = redis_client.pubsub()
            logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")

            # Determine channels
            subscription_channels = _get_subscription_channels(tenant_id, channel_filters)

            # Subscribe
            if subscription_channels:
                await pubsub.subscribe(*subscription_channels)
                logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
            else:
                legacy_channel = f"alerts:{tenant_id}"
                await pubsub.subscribe(legacy_channel)

            # Connection event
            yield "event: connection\n"
            yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"

            # Initial state
            initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
            if initial_events:
                logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
                yield "event: initial_state\n"
                yield f"data: {json.dumps(initial_events)}\n\n"

            heartbeat_counter = 0
            while True:
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                # Block for up to 10 seconds waiting for the next pub/sub message
                # (avoids busy-polling the connection).
                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=10.0)
                if message and message['type'] == 'message':
                    event_data = json.loads(message['data'])
                    event_type = _determine_event_type(event_data)
                    event_data['_channel'] = (
                        message['channel'].decode('utf-8')
                        if isinstance(message['channel'], bytes)
                        else message['channel']
                    )
                    yield f"event: {event_type}\n"
                    yield f"data: {json.dumps(event_data)}\n\n"
                else:
                    # No deliverable message in this polling window; emit a heartbeat
                    # roughly every 100 seconds of idle time (10 windows of 10s).
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield "event: heartbeat\n"
                        yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
        finally:
            if pubsub:
                try:
                    await pubsub.unsubscribe()
                    await pubsub.close()
                except Exception as e:
                    logger.error(f"Error closing pubsub: {e}")
            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )
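
# A minimal client sketch for this stream (illustrative only, not part of the gateway).
# It assumes requests are authorized with a bearer token accepted by AuthMiddleware and
# that gateway_url points at this service:
#
#   import asyncio
#   import httpx
#
#   async def listen(gateway_url: str, token: str, tenant_id: str) -> None:
#       params = {"tenant_id": tenant_id, "channels": "inventory.*,recommendations"}
#       headers = {"Authorization": f"Bearer {token}"}
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream("GET", f"{gateway_url}/api/v1/events",
#                                    params=params, headers=headers) as response:
#               async for line in response.aiter_lines():
#                   if line.startswith("data: "):
#                       print(line[len("data: "):])  # one JSON payload per event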


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """WebSocket proxy with token verification for training service"""
    token = websocket.query_params.get("token")
    if not token:
        await websocket.accept()
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Verify token
    from shared.auth.jwt_handler import JWTHandler
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
    try:
        payload = jwt_handler.verify_token(token)
        if not payload or not payload.get('user_id'):
            await websocket.accept()
            await websocket.close(code=1008, reason="Invalid token")
            return
    except Exception:
        await websocket.accept()
        await websocket.close(code=1008, reason="Token verification failed")
        return

    await websocket.accept()

    # Build WebSocket URL to training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    training_ws = None
    try:
        import websockets
        from websockets.protocol import State

        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=120,
            ping_timeout=60,
            close_timeout=60,
            open_timeout=30
        )

        async def forward_frontend_to_training():
            try:
                while training_ws and training_ws.state == State.OPEN:
                    data = await websocket.receive()
                    if data.get("type") == "websocket.receive":
                        if "text" in data:
                            await training_ws.send(data["text"])
                        elif "bytes" in data:
                            await training_ws.send(data["bytes"])
                    elif data.get("type") == "websocket.disconnect":
                        break
            except Exception:
                pass

        async def forward_training_to_frontend():
            try:
                while training_ws and training_ws.state == State.OPEN:
                    message = await training_ws.recv()
                    await websocket.send_text(message)
            except Exception:
                pass

        # Pump frames in both directions until either side disconnects
        await asyncio.gather(
            forward_frontend_to_training(),
            forward_training_to_frontend(),
            return_exceptions=True
        )

    except Exception as e:
        logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        if training_ws and training_ws.state == State.OPEN:
            try:
                await training_ws.close()
            except Exception:
                pass
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception:
            pass


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
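
# A minimal client sketch for the training progress proxy (illustrative only, not part
# of the gateway). It assumes the same JWT used for HTTP auth is passed via the `token`
# query parameter, as required by the handler above:
#
#   import asyncio
#   import websockets
#
#   async def watch(gateway_ws_url: str, tenant_id: str, job_id: str, token: str) -> None:
#       url = (f"{gateway_ws_url}/api/v1/tenants/{tenant_id}"
#              f"/training/jobs/{job_id}/live?token={token}")
#       async with websockets.connect(url) as ws:
#           async for message in ws:  # progress updates forwarded from the training service
#               print(message)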