# bakery-ia/gateway/app/main.py
"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
import redis.asyncio as aioredis
from typing import Dict, Any
from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.routes import auth, tenant, notification, nominatim, user, subscription, demo
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()
# Create FastAPI app
app = FastAPI(
title="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")
# Service discovery
service_discovery = ServiceDiscovery()
# Redis client for SSE streaming
redis_client = None
# CORS middleware - Add first
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Custom middleware - Add in REVERSE order (last added = first executed)
# Execution order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware) # Executes 5th (innermost, closest to the route handlers)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 4th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 3rd
app.add_middleware(AuthMiddleware) # Executes 2nd - Checks for demo context
app.add_middleware(DemoMiddleware) # Executes 1st (outermost) - Sets demo user context FIRST
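# Conceptually (illustrative only, not executed): Starlette wraps the router in reverse
# add order, so the resulting call chain for an incoming request is equivalent to:
#
#   DemoMiddleware(                                   # outermost - sees the request first
#       AuthMiddleware(
#           SubscriptionMiddleware(
#               RateLimitMiddleware(
#                   LoggingMiddleware(route_handlers)  # innermost - runs last before the routes
#               )
#           )
#       )
#   )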
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(user.router, prefix="/api/v1/users", tags=["users"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Connect to Redis for SSE streaming
try:
redis_client = aioredis.from_url(settings.REDIS_URL)
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
metrics_collector.register_counter(
"gateway_auth_requests_total",
"Total authentication requests"
)
metrics_collector.register_counter(
"gateway_auth_responses_total",
"Total authentication responses"
)
metrics_collector.register_counter(
"gateway_auth_errors_total",
"Total authentication errors"
)
metrics_collector.register_histogram(
"gateway_request_duration_seconds",
"Request duration in seconds"
)
logger.info("Metrics registered successfully")
metrics_collector.start_metrics_server(8080)
logger.info("API Gateway started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
global redis_client
logger.info("Shutting down API Gateway")
# Close Redis connection
if redis_client:
await redis_client.close()
# Clean up service discovery
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}
@app.get("/metrics")
async def metrics():
"""Metrics endpoint for monitoring"""
return {"metrics": "enabled"}
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================
@app.get("/api/events")
async def events_stream(request: Request, tenant_id: str):
"""
Server-Sent Events stream for real-time notifications.
Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).
Tenant ID is provided by the frontend as a query parameter.
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract user context from request state (set by auth middleware)
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
logger.info(f"SSE connection established for tenant: {tenant_id}")
async def event_generator():
"""Generate server-sent events from Redis pub/sub"""
pubsub = None
try:
# Subscribe to tenant-specific alert channel
pubsub = redis_client.pubsub()
channel_name = f"alerts:{tenant_id}"
await pubsub.subscribe(channel_name)
# Send initial connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
while True:
# Check if client has disconnected
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break
try:
# Get message from Redis with timeout
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
if message and message['type'] == 'message':
# Forward the alert/notification from Redis
alert_data = json.loads(message['data'])
# Determine event type based on alert data
event_type = "notification"
if alert_data.get('item_type') == 'alert':
if alert_data.get('severity') in ['high', 'urgent']:
event_type = "inventory_alert"
else:
event_type = "notification"
elif alert_data.get('item_type') == 'recommendation':
event_type = "notification"
yield f"event: {event_type}\n"
yield f"data: {json.dumps(alert_data)}\n\n"
logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
except asyncio.TimeoutError:
# Send heartbeat every 10 timeouts (100 seconds)
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
except asyncio.CancelledError:
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}")
finally:
if pubsub:
await pubsub.unsubscribe()
await pubsub.close()
logger.info(f"SSE connection closed for tenant: {tenant_id}")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control",
}
)
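# Illustrative sketch (not part of the gateway): any backend service can push an event into
# the SSE stream above by publishing JSON on the tenant's Redis channel. The channel name and
# the item_type/severity/title fields mirror what the generator inspects; the real payloads
# produced by the notification service may carry additional fields.
#
#   import json
#   import redis.asyncio as aioredis
#
#   async def publish_test_alert(redis_url: str, tenant_id: str) -> None:
#       r = aioredis.from_url(redis_url)
#       payload = {
#           "item_type": "alert",
#           "severity": "high",      # 'high'/'urgent' alerts are forwarded as inventory_alert
#           "title": "Low stock: flour",
#       }
#       # Publish on the channel the SSE generator subscribes to: alerts:{tenant_id}
#       await r.publish(f"alerts:{tenant_id}", json.dumps(payload))
#       await r.close()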
# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""WebSocket proxy that forwards connections directly to training service with enhanced token validation"""
await websocket.accept()
# Get token from query params
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate token using auth middleware
from app.middleware.auth import jwt_handler
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Check token expiration
if payload.get('exp', 0) < time.time():
logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
await websocket.close(code=1008, reason="Token expired")
return
logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
training_ws = None
heartbeat_task = None
try:
# Connect to training service WebSocket with proper timeout configuration
import websockets
# Configure timeouts to coordinate with frontend (30s heartbeat) and training service
# DISABLE gateway-level ping to avoid dual-ping conflicts - let frontend handle ping/pong
training_ws = await websockets.connect(
training_ws_url,
ping_interval=None, # DISABLED: Let frontend handle ping/pong via message forwarding
ping_timeout=None, # DISABLED: No independent ping mechanism
close_timeout=15, # Reasonable close timeout
max_size=2**20, # 1MB max message size
max_queue=32 # Max queued messages
)
logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)")
# Track connection state properly due to FastAPI WebSocket state propagation bug
connection_alive = True
last_activity = asyncio.get_event_loop().time()
async def check_connection_health():
"""Monitor connection health based on activity timestamps only - no WebSocket interference"""
nonlocal connection_alive, last_activity
while connection_alive:
try:
await asyncio.sleep(30) # Check every 30 seconds (aligned with frontend heartbeat)
current_time = asyncio.get_event_loop().time()
# Check if we haven't received any activity for too long
# Frontend sends ping every 30s, so 90s = 3 missed pings before considering dead
if current_time - last_activity > 90:
logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
# Don't forcibly close - let the forwarding loops handle actual connection issues
# This is just monitoring/logging now
else:
logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")
except Exception as e:
logger.error(f"Connection health monitoring error for job {job_id}: {e}")
break
async def forward_to_training():
"""Forward messages from frontend to training service with proper error handling"""
nonlocal connection_alive, last_activity
try:
while connection_alive and training_ws and training_ws.open:
try:
# Use longer timeout to avoid conflicts with frontend 30s heartbeat
# Frontend sends ping every 30s, so we need to allow for some latency
data = await asyncio.wait_for(websocket.receive(), timeout=45.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data.get("type") == "websocket.receive":
if "text" in data:
message = data["text"]
# Forward text messages to training service
await training_ws.send(message)
logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
elif "bytes" in data:
# Forward binary messages if needed
await training_ws.send(data["bytes"])
# Ping/pong frames are automatically handled by Starlette/FastAPI
except asyncio.TimeoutError:
# No message received in 45 seconds, continue loop
# This allows for frontend 30s heartbeat + network latency + processing time
continue
except Exception as e:
logger.error(f"Error receiving from frontend for job {job_id}: {e}")
connection_alive = False
break
except Exception as e:
logger.error(f"Error in forward_to_training for job {job_id}: {e}")
connection_alive = False
async def forward_to_frontend():
"""Forward messages from training service to frontend with proper error handling"""
nonlocal connection_alive, last_activity
try:
while connection_alive and training_ws and training_ws.open:
try:
# Use coordinated timeout - training service expects messages every 60s
# This should be longer than training service timeout to avoid premature closure
message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
last_activity = asyncio.get_event_loop().time()
# Forward the message to frontend
await websocket.send_text(message)
logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")
except asyncio.TimeoutError:
# No message received in 75 seconds, continue loop
# Training service sends heartbeats, so this indicates potential issues
continue
except Exception as e:
logger.error(f"Error receiving from training service for job {job_id}: {e}")
connection_alive = False
break
except Exception as e:
logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
connection_alive = False
# Start connection health monitoring
heartbeat_task = asyncio.create_task(check_connection_health())
# Run both forwarding tasks concurrently with proper error handling
try:
await asyncio.gather(
forward_to_training(),
forward_to_frontend(),
return_exceptions=True
)
except Exception as e:
logger.error(f"Error in WebSocket forwarding tasks for job {job_id}: {e}")
finally:
connection_alive = False
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
except websockets.exceptions.WebSocketException as e:
logger.error(f"WebSocket exception for job {job_id}: {e}")
except Exception as e:
logger.error(f"WebSocket proxy error for job {job_id}: {e}")
finally:
# Cleanup
if heartbeat_task and not heartbeat_task.done():
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
if training_ws and not training_ws.closed:
try:
await training_ws.close()
except Exception as e:
logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")
try:
            if websocket.client_state.name != 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy connection closed")
except Exception as e:
logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")
logger.info(f"WebSocket proxy cleanup completed for job {job_id}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)