"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
import redis.asyncio as aioredis
from typing import Dict, Any
from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.routes import auth, tenant, notification, nominatim, user, subscription, demo, pos
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector

# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)

logger = structlog.get_logger()

# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Service discovery
service_discovery = ServiceDiscovery()

# Redis client for SSE streaming
redis_client = None

# CORS middleware - add first
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Custom middleware - added in reverse order (last added = outermost = runs first on a request)
# Request path: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)  # Runs 5th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Runs 4th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Runs 3rd
app.add_middleware(AuthMiddleware)  # Runs 2nd - checks for demo context
app.add_middleware(DemoMiddleware)  # Runs 1st (outermost) - sets demo user context first
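
# Conceptually, the resulting stack wraps the app like
#   DemoMiddleware(AuthMiddleware(SubscriptionMiddleware(RateLimitMiddleware(LoggingMiddleware(app)))))
# so an incoming request passes Demo -> Auth -> Subscription -> RateLimit -> Logging before reaching a route.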

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(user.router, prefix="/api/v1/users", tags=["users"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Connect to Redis for SSE streaming
try:
redis_client = aioredis.from_url(settings.REDIS_URL)
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
2025-08-02 21:56:25 +02:00
2025-07-26 21:10:54 +02:00
metrics_collector.register_counter(
"gateway_auth_requests_total",
2025-08-02 21:56:25 +02:00
"Total authentication requests"
2025-07-26 21:10:54 +02:00
)
metrics_collector.register_counter(
"gateway_auth_responses_total",
2025-08-02 21:56:25 +02:00
"Total authentication responses"
)
metrics_collector.register_counter(
"gateway_auth_errors_total",
"Total authentication errors"
2025-07-26 21:10:54 +02:00
)
2025-08-02 21:56:25 +02:00
2025-07-26 21:10:54 +02:00
metrics_collector.register_histogram(
"gateway_request_duration_seconds",
2025-08-02 21:56:25 +02:00
"Request duration in seconds"
2025-07-26 21:10:54 +02:00
)
2025-08-02 21:56:25 +02:00
logger.info("Metrics registered successfully")
2025-07-26 21:10:54 +02:00
metrics_collector.start_metrics_server(8080)
logger.info("API Gateway started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
global redis_client
logger.info("Shutting down API Gateway")
# Close Redis connection
if redis_client:
await redis_client.close()
# Clean up service discovery
2025-07-18 12:57:13 +02:00
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
2025-07-17 19:54:04 +02:00
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}
@app.get("/metrics")
2025-07-17 19:54:04 +02:00
async def metrics():
"""Metrics endpoint for monitoring"""
return {"metrics": "enabled"}
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================
@app.get("/api/events")
async def events_stream(request: Request, tenant_id: str):
"""
Server-Sent Events stream for real-time notifications.
Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).
Tenant ID is provided by the frontend as a query parameter.
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract user context from request state (set by auth middleware)
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
logger.info(f"SSE connection established for tenant: {tenant_id}")
async def event_generator():
"""Generate server-sent events from Redis pub/sub"""
pubsub = None
try:
# Subscribe to tenant-specific alert channel
pubsub = redis_client.pubsub()
channel_name = f"alerts:{tenant_id}"
await pubsub.subscribe(channel_name)
# Send initial connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
while True:
# Check if client has disconnected
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break
try:
# Get message from Redis with timeout
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
if message and message['type'] == 'message':
# Forward the alert/notification from Redis
alert_data = json.loads(message['data'])
# Determine event type based on alert data
event_type = "notification"
if alert_data.get('item_type') == 'alert':
if alert_data.get('severity') in ['high', 'urgent']:
event_type = "inventory_alert"
else:
event_type = "notification"
elif alert_data.get('item_type') == 'recommendation':
event_type = "notification"
yield f"event: {event_type}\n"
yield f"data: {json.dumps(alert_data)}\n\n"
logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
except asyncio.TimeoutError:
# Send heartbeat every 10 timeouts (100 seconds)
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
except asyncio.CancelledError:
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}")
finally:
if pubsub:
await pubsub.unsubscribe()
await pubsub.close()
logger.info(f"SSE connection closed for tenant: {tenant_id}")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control",
}
)
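
# Illustrative client sketch for this endpoint (a minimal example assuming a local gateway on
# port 8000 and a token accepted by the auth middleware; the names below are hypothetical and
# not part of this module):
#
#   import httpx
#
#   async def listen_for_alerts(tenant_id: str, token: str) -> None:
#       url = f"http://localhost:8000/api/events?tenant_id={tenant_id}&token={token}"
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream("GET", url) as response:
#               async for line in response.aiter_lines():
#                   print(line)  # raw SSE frames: "event: ..." then "data: {...}"
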
# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """
    Simple WebSocket proxy with token verification only.
    Validates the token and forwards the connection to the training service.
    """
    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        logger.warning("WebSocket proxy rejected - missing token",
                       job_id=job_id,
                       tenant_id=tenant_id)
        # Accept before closing so the close code reaches the client
        await websocket.accept()
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Verify token
    from shared.auth.jwt_handler import JWTHandler
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
    try:
        payload = jwt_handler.verify_token(token)
        if not payload or not payload.get('user_id'):
            logger.warning("WebSocket proxy rejected - invalid token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            await websocket.accept()
            await websocket.close(code=1008, reason="Invalid token")
            return
        logger.info("WebSocket proxy - token verified",
                    user_id=payload.get('user_id'),
                    tenant_id=tenant_id,
                    job_id=job_id)
    except Exception as e:
        logger.warning("WebSocket proxy - token verification failed",
                       job_id=job_id,
                       error=str(e))
        await websocket.accept()
        await websocket.close(code=1008, reason="Token verification failed")
        return

    # Accept the connection
    await websocket.accept()

    # Build WebSocket URL to training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    logger.info("Gateway proxying WebSocket to training service",
                job_id=job_id,
                training_ws_url=training_ws_url.replace(token, '***'))

    training_ws = None
    try:
        # Connect to training service WebSocket
        import websockets
        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=120,  # Send ping every 2 minutes (tolerates long training operations)
            ping_timeout=60,    # Wait up to 1 minute for pong (graceful timeout)
            close_timeout=60,   # Generous close timeout for graceful shutdown
            open_timeout=30
        )
        logger.info("Gateway connected to training service WebSocket", job_id=job_id)

        async def forward_frontend_to_training():
            """Forward messages from frontend to training service"""
            try:
                while training_ws and training_ws.open:
                    data = await websocket.receive()
                    if data.get("type") == "websocket.receive":
                        if "text" in data:
                            await training_ws.send(data["text"])
                        elif "bytes" in data:
                            await training_ws.send(data["bytes"])
                    elif data.get("type") == "websocket.disconnect":
                        break
            except Exception as e:
                logger.debug("Frontend to training forward ended", error=str(e))

        async def forward_training_to_frontend():
            """Forward messages from training service to frontend"""
            message_count = 0
            try:
                while training_ws and training_ws.open:
                    message = await training_ws.recv()
                    await websocket.send_text(message)
                    message_count += 1
                    # Log every 10th message to track connectivity
                    if message_count % 10 == 0:
                        logger.debug("WebSocket proxy active",
                                     job_id=job_id,
                                     messages_forwarded=message_count)
            except Exception as e:
                logger.info("Training to frontend forward ended",
                            job_id=job_id,
                            messages_forwarded=message_count,
                            error=str(e))

        # Run both forwarding tasks concurrently
        await asyncio.gather(
            forward_frontend_to_training(),
            forward_training_to_frontend(),
            return_exceptions=True
        )
    except Exception as e:
        logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        # Cleanup
        if training_ws and not training_ws.closed:
            try:
                await training_ws.close()
            except Exception:
                pass
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception:
            pass
        logger.info("WebSocket proxy connection closed", job_id=job_id)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)