""" API Gateway - Central entry point for all microservices Handles routing, authentication, rate limiting, and cross-cutting concerns """ import asyncio import json import structlog import resource import os from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, Response import httpx import time from shared.redis_utils import initialize_redis, close_redis, get_redis_client from typing import Dict, Any from app.core.config import settings from app.middleware.request_id import RequestIDMiddleware from app.middleware.auth import AuthMiddleware from app.middleware.logging import LoggingMiddleware from app.middleware.rate_limit import RateLimitMiddleware from app.middleware.rate_limiting import APIRateLimitMiddleware from app.middleware.subscription import SubscriptionMiddleware from app.middleware.demo_middleware import DemoMiddleware from app.middleware.read_only_mode import ReadOnlyModeMiddleware from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context from shared.monitoring.logging import setup_logging from shared.monitoring.metrics import MetricsCollector, add_metrics_middleware # OpenTelemetry imports from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor from opentelemetry.sdk.resources import Resource # Configure OpenTelemetry tracing def setup_tracing(service_name: str = "gateway"): """Initialize OpenTelemetry tracing with OTLP exporter for Jaeger""" # Create resource with service name 
resource = Resource.create({"service.name": service_name}) # Configure OTLP exporter (sends to OpenTelemetry Collector) otlp_exporter = OTLPSpanExporter( endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://otel-collector.monitoring.svc.cluster.local:4317"), insecure=True # Use insecure connection for internal cluster communication ) # Configure tracer provider provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(otlp_exporter) provider.add_span_processor(processor) # Set global tracer provider trace.set_tracer_provider(provider) return provider # Initialize tracing tracer_provider = setup_tracing("gateway") # Setup logging setup_logging("gateway", settings.LOG_LEVEL) logger = structlog.get_logger() # Check file descriptor limits and warn if too low try: soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) if soft_limit < 1024: logger.warning(f"Low file descriptor limit detected: {soft_limit}. Gateway may experience 'too many open files' errors.") logger.warning(f"Recommended: Increase limit with 'ulimit -n 4096' or higher for production.") if soft_limit < 256: logger.error(f"Critical: File descriptor limit ({soft_limit}) is too low for gateway operation!") else: logger.info(f"File descriptor limit: {soft_limit} (sufficient)") except Exception as e: logger.debug(f"Could not check file descriptor limits: {e}") # Check and log current working directory and permissions try: cwd = os.getcwd() logger.info(f"Current working directory: {cwd}") # Check if we can write to common log locations test_locations = ["/var/log", "./logs", "."] for location in test_locations: try: test_file = os.path.join(location, ".gateway_permission_test") with open(test_file, 'w') as f: f.write("test") os.remove(test_file) logger.info(f"Write permission confirmed for: {location}") except Exception as e: logger.warning(f"Cannot write to {location}: {e}") except Exception as e: logger.debug(f"Could not check directory permissions: {e}") # Create FastAPI 
# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    redirect_slashes=False  # Disable automatic trailing slash redirects
)

# Instrument FastAPI with OpenTelemetry
FastAPIInstrumentor.instrument_app(app)

# Instrument httpx for outgoing requests
HTTPXClientInstrumentor().instrument()

# Instrument Redis (will be active once redis client is initialized)
RedisInstrumentor().instrument()

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Add metrics middleware to track HTTP requests
add_metrics_middleware(app, metrics_collector)

# Redis client for SSE streaming (assigned in startup_event)
redis_client = None

# CORS middleware - Add first
# NOTE(review): being added first makes CORS the innermost of the middleware
# registered in this file, so responses produced directly by the middleware
# added below (e.g. an auth rejection) do not pass through it - confirm this
# is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Custom middleware - Add in REVERSE order (last added = outermost = first executed)
# Execution order on a request:
#   RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware
#   -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
# APIRateLimitMiddleware is registered in startup_event because it needs the Redis client.
app.add_middleware(LoggingMiddleware)  # Executes last of the custom middleware (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Simple rate limit
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Enforce read-only mode
app.add_middleware(AuthMiddleware)  # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware)  # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware)  # Executes 1st (outermost) - Generates request ID for tracing

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"])  # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


@app.on_event("startup")
async def startup_event():
    """Connect to Redis, enable Redis-backed rate limiting and register metrics.

    Neither a Redis connection failure nor a failure to register the rate
    limiting middleware aborts startup; both are logged and the gateway keeps
    running. If Redis is unavailable ``redis_client`` stays ``None`` and the
    SSE endpoint answers 503.
    """
    global redis_client

    logger.info("Starting API Gateway")

    # Initialize shared Redis connection
    try:
        await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
        redis_client = await get_redis_client()
        logger.info("Connected to Redis for SSE streaming")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")
        logger.warning("API rate limiting middleware will fail open (allow all requests)")
    else:
        # Kept apart from the connection attempt so that a registration
        # failure is not misreported as a Redis connection failure.
        # NOTE(review): Starlette raises RuntimeError from add_middleware()
        # once the middleware stack has been built; verify with the pinned
        # Starlette version that registration at startup actually succeeds.
        # When it does, this middleware becomes the outermost one.
        try:
            app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
            logger.info("API rate limiting middleware enabled with subscription-based quotas")
        except Exception as e:
            logger.error(f"Failed to enable API rate limiting middleware: {e}")
            logger.warning("API rate limiting is disabled (all requests allowed)")

    metrics_collector.register_counter(
        "gateway_auth_requests_total",
        "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total",
        "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total",
        "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds",
        "Request duration in seconds"
    )

    logger.info("Metrics registered successfully")

    metrics_collector.start_metrics_server(8080)

    logger.info("API Gateway started successfully")


@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared Redis connection on application shutdown."""
    logger.info("Shutting down API Gateway")

    # Close shared Redis connection
    await close_redis()

    # Clean up service discovery
    # await service_discovery.cleanup()

    logger.info("API Gateway shutdown complete")


@app.get("/health")
async def health_check():
    """Return a static health payload with the current Unix timestamp."""
    return {
        "status": "healthy",
        "service": "api-gateway",
        "version": "1.0.0",
        "timestamp": time.time()
    }


@app.get("/metrics")
async def metrics():
    """Expose the collector's metrics in the Prometheus text format."""
    return Response(
        content=metrics_collector.get_metrics(),
        media_type="text/plain; version=0.0.4; charset=utf-8"
    )


# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================

# Event domains and classes that make up the per-tenant channel names.
_EVENT_DOMAINS = ("inventory", "production", "supply_chain", "demand", "operations")
_EVENT_CLASSES = ("alerts", "notifications")

# How long one pub/sub poll may block, and how often a heartbeat is emitted.
_SSE_POLL_TIMEOUT_SECONDS = 10.0
_SSE_HEARTBEAT_INTERVAL_SECONDS = 100.0


def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
    """
    Determine which Redis channels to subscribe to based on filters.

    Args:
        tenant_id: Tenant identifier
        channel_filters: List of channel patterns
            (e.g., ["inventory.alerts", "*.notifications"])

    Returns:
        List of full channel names to subscribe to, without duplicates and in
        the order in which they were first produced.

    Examples:
        ["inventory.alerts"] -> ["tenant:abc:inventory.alerts"]
        ["*.alerts"]         -> ["tenant:abc:inventory.alerts", "tenant:abc:production.alerts", ...]
        []                   -> every domain/class channel, the recommendations
                                channel and the legacy "alerts:abc" channel
    """
    recommendations_channel = f"tenant:{tenant_id}:recommendations"
    every_domain_channel = [
        f"tenant:{tenant_id}:{domain}.{event_class}"
        for domain in _EVENT_DOMAINS
        for event_class in _EVENT_CLASSES
    ]

    if not channel_filters:
        # Subscribe to ALL channels (backward compatible), including the
        # legacy per-tenant channel.
        return every_domain_channel + [recommendations_channel, f"alerts:{tenant_id}"]

    channels = []

    # Parse filters and expand wildcards
    for filter_pattern in channel_filters:
        if filter_pattern == "*.*":
            # All channels
            channels.extend(every_domain_channel)
            channels.append(recommendations_channel)
        elif filter_pattern.endswith(".*"):
            # Domain wildcard (e.g., "inventory.*")
            domain = filter_pattern.split(".")[0]
            channels.extend(
                f"tenant:{tenant_id}:{domain}.{event_class}" for event_class in _EVENT_CLASSES
            )
        elif filter_pattern.startswith("*."):
            # Class wildcard (e.g., "*.alerts")
            event_class = filter_pattern.split(".")[1]
            if event_class == "recommendations":
                channels.append(recommendations_channel)
            else:
                channels.extend(
                    f"tenant:{tenant_id}:{domain}.{event_class}" for domain in _EVENT_DOMAINS
                )
        elif filter_pattern == "recommendations":
            # Recommendations channel
            channels.append(recommendations_channel)
        else:
            # Specific channel (e.g., "inventory.alerts")
            channels.append(f"tenant:{tenant_id}:{filter_pattern}")

    # Remove duplicates while keeping a deterministic order
    return list(dict.fromkeys(channels))


async def _read_cached_events(redis_client, cache_key: str):
    """Return the JSON-decoded events stored under ``cache_key``.

    Returns ``None`` when the key holds no (or an empty) value. A decoded value
    that is not a list is wrapped in a single-element list so callers can
    always extend with the result.
    """
    cached_data = await redis_client.get(cache_key)
    if not cached_data:
        return None
    events = json.loads(cached_data)
    return events if isinstance(events, list) else [events]


async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
    """
    Load initial state from Redis cache based on channel filters.

    Without filters the legacy ``active_alerts:{tenant_id}`` cache wins when it
    exists. Otherwise one ``active_events:{tenant_id}:{suffix}`` cache is read
    for every channel ``tenant:{tenant_id}:{suffix}`` that
    ``_get_subscription_channels`` yields for the same filters, so the initial
    state always covers exactly the channels being streamed.

    Args:
        redis_client: Redis client
        tenant_id: Tenant identifier
        channel_filters: List of channel patterns

    Returns:
        List of initial events; an empty list if nothing is cached or if
        reading/decoding the cache fails (the error is logged).
    """
    try:
        if not channel_filters:
            # Load from legacy cache if no filters (backward compat)
            legacy_events = await _read_cached_events(redis_client, f"active_alerts:{tenant_id}")
            if legacy_events is not None:
                return legacy_events

        # NOTE(review): cache keys mirror the channel suffix exactly
        # ("inventory.alerts"). Some branches of the previous code appended an
        # extra "s" ("inventory.alertss") while others did not; confirm the
        # un-suffixed form against the service that writes these caches.
        channel_prefix = f"tenant:{tenant_id}:"
        initial_events = []
        for channel in _get_subscription_channels(tenant_id, channel_filters):
            if not channel.startswith(channel_prefix):
                # The legacy "alerts:{tenant_id}" channel has no active_events cache.
                continue
            cache_key = f"active_events:{tenant_id}:{channel[len(channel_prefix):]}"
            events = await _read_cached_events(redis_client, cache_key)
            if events:
                initial_events.extend(events)

        return initial_events

    except Exception as e:
        logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
        return []


def _determine_event_type(event_data: dict) -> str:
    """
    Determine SSE event type from event data.

    Args:
        event_data: Event data dictionary

    Returns:
        The value of 'event_class' when present; otherwise 'recommendation'
        when the legacy 'item_type' is 'recommendation'; otherwise 'alert'.
    """
    # New event architecture uses 'event_class'
    if 'event_class' in event_data:
        return event_data['event_class']

    # Legacy format uses 'item_type'
    if event_data.get('item_type') == 'recommendation':
        return 'recommendation'

    # Default to 'alert' for backward compatibility
    return 'alert'


# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/events")
async def events_stream(
    request: Request,
    tenant_id: str,
    channels: str = None  # Comma-separated channel filters (e.g., "inventory.alerts,production.notifications")
):
    """
    Server-Sent Events stream for real-time notifications with multi-channel support.

    Authentication is handled by auth middleware via query param token.
    User context is read from request.state.user (injected by middleware).

    Query Parameters:
        tenant_id: Tenant identifier (required)
        channels: Comma-separated channel filters (optional)
            Examples:
            - "inventory.alerts,production.notifications" - Specific channels
            - "*.alerts" - All alert channels
            - "inventory.*" - All inventory events
            - None - All channels (default, backward compatible)

    New channel pattern: tenant:{tenant_id}:{domain}.{class}
    Examples:
        - tenant:abc:inventory.alerts
        - tenant:abc:production.notifications
        - tenant:abc:recommendations

    Legacy channel (backward compat): alerts:{tenant_id}

    Raises:
        HTTPException: 503 when Redis is not connected, 401 when no user
            context is attached to the request, 400 when tenant_id is empty.
    """
    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract user context from request state (set by auth middleware).
    # A missing context becomes a 401 instead of an AttributeError (500).
    user_context = getattr(request.state, "user", None)
    if not user_context:
        raise HTTPException(status_code=401, detail="Authentication required")
    email = user_context.get('email')

    # Validate tenant_id parameter
    # NOTE(review): tenant_id is taken from the query string and is not
    # compared with the authenticated user in this handler - confirm that the
    # middleware enforces tenant access for this route.
    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

    # Parse channel filters
    channel_filters = []
    if channels:
        channel_filters = [c.strip() for c in channels.split(',') if c.strip()]

    logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")

    async def event_generator():
        """Generate server-sent events from Redis pub/sub with multi-channel support"""
        pubsub = None
        try:
            # Create pubsub connection with resource monitoring
            pubsub = redis_client.pubsub()
            logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")

            # Monitor connection count
            try:
                connection_info = await redis_client.info('clients')
                connected_clients = connection_info.get('connected_clients', 'unknown')
                logger.debug(f"Redis connected clients: {connected_clients}")
            except Exception:
                # Don't fail if we can't get connection info
                pass

            # Determine which channels to subscribe to
            subscription_channels = _get_subscription_channels(tenant_id, channel_filters)

            # Subscribe to all determined channels
            if subscription_channels:
                await pubsub.subscribe(*subscription_channels)
                logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
            else:
                # Fallback to legacy channel if no channels specified
                legacy_channel = f"alerts:{tenant_id}"
                await pubsub.subscribe(legacy_channel)
                logger.info(f"Subscribed to legacy channel: {legacy_channel}")

            # Send initial connection event
            connected_payload = {
                'type': 'connected',
                'message': 'SSE connection established',
                'channels': subscription_channels or ['all'],
                'timestamp': time.time(),
            }
            yield f"event: connection\n"
            yield f"data: {json.dumps(connected_payload)}\n\n"

            # Fetch and send initial state from cache (domain-specific or legacy);
            # an empty list is sent when nothing is cached, for compatibility.
            initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
            if initial_events:
                logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
            yield f"event: initial_state\n"
            yield f"data: {json.dumps(initial_events or [])}\n\n"

            last_heartbeat = time.monotonic()

            while True:
                # Check if client has disconnected
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                # Block inside get_message for up to the poll timeout. Without
                # the timeout argument get_message returns immediately, which
                # turns this loop into a busy-wait and never reaches the
                # heartbeat.
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=_SSE_POLL_TIMEOUT_SECONDS,
                )

                if message and message['type'] == 'message':
                    # A single malformed payload must not tear down the stream.
                    try:
                        event_data = json.loads(message['data'])
                    except (TypeError, ValueError) as parse_error:
                        logger.warning(f"Skipping malformed SSE payload for tenant {tenant_id}: {parse_error}")
                        event_data = None

                    if isinstance(event_data, dict):
                        # Determine event type for SSE
                        event_type = _determine_event_type(event_data)

                        # Add channel metadata for frontend routing
                        channel = message['channel']
                        event_data['_channel'] = channel.decode('utf-8') if isinstance(channel, bytes) else channel

                        yield f"event: {event_type}\n"
                        yield f"data: {json.dumps(event_data)}\n\n"

                        logger.debug(f"SSE event sent to tenant {tenant_id}: {event_type} - {event_data.get('title')}")
                    elif event_data is not None:
                        logger.warning(f"Skipping non-object SSE payload for tenant {tenant_id}")

                # Heartbeat driven by the clock rather than by counting
                # timeouts, so it fires regardless of how the polls returned.
                now = time.monotonic()
                if now - last_heartbeat >= _SSE_HEARTBEAT_INTERVAL_SECONDS:
                    yield f"event: heartbeat\n"
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                    last_heartbeat = now

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
            # Propagate so the cancelling task sees the cancellation; cleanup
            # still runs in the finally block.
            raise
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
        finally:
            if pubsub:
                try:
                    # Unsubscribe from all channels
                    await pubsub.unsubscribe()
                    logger.debug(f"Unsubscribed from Redis channels for tenant: {tenant_id}")
                except Exception as unsubscribe_error:
                    logger.error(f"Failed to unsubscribe Redis pubsub for tenant {tenant_id}: {unsubscribe_error}")

                try:
                    # Close pubsub connection
                    await pubsub.close()
                    logger.debug(f"Closed Redis pubsub connection for tenant: {tenant_id}")
                except Exception as close_error:
                    logger.error(f"Failed to close Redis pubsub for tenant {tenant_id}: {close_error}")

            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """
    Simple WebSocket proxy with token verification only.

    Validates the ``token`` query parameter, then connects to the training
    service and relays frames in both directions. The proxy shuts down both
    connections as soon as either side stops.

    Args:
        websocket: Incoming client connection.
        tenant_id: Tenant identifier from the path, forwarded to the training service.
        job_id: Training job identifier from the path, forwarded to the training service.
    """
    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        logger.warning("WebSocket proxy rejected - missing token",
                       job_id=job_id,
                       tenant_id=tenant_id)
        # The connection is accepted first so the close code/reason reach the client.
        await websocket.accept()
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Verify token
    from shared.auth.jwt_handler import JWTHandler

    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    try:
        payload = jwt_handler.verify_token(token)
        if not payload or not payload.get('user_id'):
            logger.warning("WebSocket proxy rejected - invalid token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            await websocket.accept()
            await websocket.close(code=1008, reason="Invalid token")
            return

        # NOTE(review): only the token itself is verified; tenant_id from the
        # path is not compared with the token payload here - confirm the
        # training service enforces tenant access.
        logger.info("WebSocket proxy - token verified",
                    user_id=payload.get('user_id'),
                    tenant_id=tenant_id,
                    job_id=job_id)

    except Exception as e:
        logger.warning("WebSocket proxy - token verification failed",
                       job_id=job_id,
                       error=str(e))
        await websocket.accept()
        await websocket.close(code=1008, reason="Token verification failed")
        return

    # Accept the connection
    await websocket.accept()

    # Build WebSocket URL to training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    # The token is masked so it never reaches the logs.
    logger.info("Gateway proxying WebSocket to training service",
                job_id=job_id,
                training_ws_url=training_ws_url.replace(token, '***'))

    training_ws = None

    try:
        # Connect to training service WebSocket
        import websockets
        from websockets.protocol import State

        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=120,  # Send ping every 2 minutes (tolerates long training operations)
            ping_timeout=60,    # Wait up to 1 minute for pong (graceful timeout)
            close_timeout=60,   # Increase close timeout for graceful shutdown
            open_timeout=30
        )

        logger.info("Gateway connected to training service WebSocket", job_id=job_id)

        async def forward_frontend_to_training():
            """Forward messages from frontend to training service"""
            try:
                while training_ws and training_ws.state == State.OPEN:
                    data = await websocket.receive()
                    if data.get("type") == "websocket.receive":
                        # A frame may carry a key whose value is None, so test
                        # the value rather than the key's presence.
                        if data.get("text") is not None:
                            await training_ws.send(data["text"])
                        elif data.get("bytes") is not None:
                            await training_ws.send(data["bytes"])
                    elif data.get("type") == "websocket.disconnect":
                        break
            except Exception as e:
                logger.debug("Frontend to training forward ended", error=str(e))

        async def forward_training_to_frontend():
            """Forward messages from training service to frontend"""
            message_count = 0
            try:
                while training_ws and training_ws.state == State.OPEN:
                    message = await training_ws.recv()
                    # Binary frames must be relayed as bytes, not as text.
                    if isinstance(message, (bytes, bytearray)):
                        await websocket.send_bytes(bytes(message))
                    else:
                        await websocket.send_text(message)
                    message_count += 1

                    # Log every 10th message to track connectivity
                    if message_count % 10 == 0:
                        logger.debug("WebSocket proxy active",
                                     job_id=job_id,
                                     messages_forwarded=message_count)
            except Exception as e:
                logger.info("Training to frontend forward ended",
                            job_id=job_id,
                            messages_forwarded=message_count,
                            error=str(e))

        # Run both directions concurrently and stop as soon as either ends;
        # otherwise the surviving direction would keep a half-open proxy
        # alive until its own peer happened to send or disconnect.
        forward_tasks = [
            asyncio.ensure_future(forward_frontend_to_training()),
            asyncio.ensure_future(forward_training_to_frontend()),
        ]
        try:
            await asyncio.wait(forward_tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for forward_task in forward_tasks:
                if not forward_task.done():
                    forward_task.cancel()
            await asyncio.gather(*forward_tasks, return_exceptions=True)

    except Exception as e:
        logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        # Cleanup: best effort on both sides, failures are only logged.
        if training_ws is not None:
            try:
                await training_ws.close()
            except Exception as close_error:
                logger.debug("Closing training service WebSocket failed",
                             job_id=job_id,
                             error=str(close_error))
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception as close_error:
            logger.debug("Closing frontend WebSocket failed",
                         job_id=job_id,
                         error=str(close_error))

        logger.info("WebSocket proxy connection closed", job_id=job_id)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)