# bakery-ia/gateway/app/main.py
"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
import resource
import os
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from typing import Dict, Any
from app.core.config import settings
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()
# Check file descriptor limits and warn if too low
try:
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft_limit < 1024:
        logger.warning(f"Low file descriptor limit detected: {soft_limit}. Gateway may experience 'too many open files' errors.")
        logger.warning("Recommended: Increase limit with 'ulimit -n 4096' or higher for production.")
        if soft_limit < 256:
            logger.error(f"Critical: File descriptor limit ({soft_limit}) is too low for gateway operation!")
    else:
        logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
    logger.debug(f"Could not check file descriptor limits: {e}")
# Check and log current working directory and permissions
try:
    cwd = os.getcwd()
    logger.info(f"Current working directory: {cwd}")
    # Check if we can write to common log locations
    test_locations = ["/var/log", "./logs", "."]
    for location in test_locations:
        try:
            test_file = os.path.join(location, ".gateway_permission_test")
            with open(test_file, 'w') as f:
                f.write("test")
            os.remove(test_file)
            logger.info(f"Write permission confirmed for: {location}")
        except Exception as e:
            logger.warning(f"Cannot write to {location}: {e}")
except Exception as e:
    logger.debug(f"Could not check directory permissions: {e}")
# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    redirect_slashes=False  # Disable automatic trailing slash redirects
)
# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")
# Redis client for SSE streaming
redis_client = None
# CORS middleware - Add first
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Custom middleware - Add in REVERSE order (last added = first executed)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)  # Executes 7th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Executes 6th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 5th
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 4th - Enforces read-only mode
app.add_middleware(AuthMiddleware)  # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware)  # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware)  # Executes 1st (outermost) - Generates request ID for tracing
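# A minimal ordering sketch (general Starlette behavior, not specific to this
# app): each add_middleware call wraps the current app, so the middleware
# added last sits outermost and sees the request first.
#
#     app.add_middleware(A)  # added first -> innermost
#     app.add_middleware(B)  # added last  -> outermost
#     # Request flow: B -> A -> route handler -> A -> B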
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"]) # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Initialize shared Redis connection
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
metrics_collector.register_counter(
"gateway_auth_requests_total",
"Total authentication requests"
)
metrics_collector.register_counter(
"gateway_auth_responses_total",
"Total authentication responses"
)
metrics_collector.register_counter(
"gateway_auth_errors_total",
"Total authentication errors"
)
metrics_collector.register_histogram(
"gateway_request_duration_seconds",
"Request duration in seconds"
)
logger.info("Metrics registered successfully")
metrics_collector.start_metrics_server(8080)
logger.info("API Gateway started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
logger.info("Shutting down API Gateway")
# Close shared Redis connection
await close_redis()
# Clean up service discovery
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}
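# Note: the detailed Prometheus-style metrics are served by the standalone
# metrics server that MetricsCollector starts on port 8080 in startup_event;
# this /metrics route is only a lightweight status stub.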
@app.get("/metrics")
async def metrics():
"""Metrics endpoint for monitoring"""
return {"metrics": "enabled"}
# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================
def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
"""
Determine which Redis channels to subscribe to based on filters.
Args:
tenant_id: Tenant identifier
channel_filters: List of channel patterns (e.g., ["inventory.alerts", "*.notifications"])
Returns:
List of full channel names to subscribe to
Examples:
>>> _get_subscription_channels("abc", ["inventory.alerts"])
["tenant:abc:inventory.alerts"]
>>> _get_subscription_channels("abc", ["*.alerts"])
["tenant:abc:inventory.alerts", "tenant:abc:production.alerts", ...]
>>> _get_subscription_channels("abc", [])
["tenant:abc:inventory.alerts", "tenant:abc:inventory.notifications", ...]
"""
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
channels = []
if not channel_filters:
# Subscribe to ALL channels (backward compatible)
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
# Also subscribe to recommendations (tenant-wide)
channels.append(f"tenant:{tenant_id}:recommendations")
# Also subscribe to legacy channel for backward compatibility
channels.append(f"alerts:{tenant_id}")
return channels
# Parse filters and expand wildcards
for filter_pattern in channel_filters:
if filter_pattern == "*.*":
# All channels
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")
elif filter_pattern.endswith(".*"):
# Domain wildcard (e.g., "inventory.*")
domain = filter_pattern.split(".")[0]
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern.startswith("*."):
# Class wildcard (e.g., "*.alerts")
event_class = filter_pattern.split(".")[1]
if event_class == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
for domain in all_domains:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern == "recommendations":
# Recommendations channel
channels.append(f"tenant:{tenant_id}:recommendations")
else:
# Specific channel (e.g., "inventory.alerts")
channels.append(f"tenant:{tenant_id}:{filter_pattern}")
return list(set(channels)) # Remove duplicates
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
"""
Load initial state from Redis cache based on channel filters.
Args:
redis_client: Redis client
tenant_id: Tenant identifier
channel_filters: List of channel patterns
Returns:
List of initial events
"""
initial_events = []
try:
if not channel_filters:
# Load from legacy cache if no filters (backward compat)
legacy_cache_key = f"active_alerts:{tenant_id}"
cached_data = await redis_client.get(legacy_cache_key)
if cached_data:
return json.loads(cached_data)
# Also try loading from new domain-specific caches
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
for domain in all_domains:
for event_class in all_classes:
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
events = json.loads(cached_data)
initial_events.extend(events)
# Load recommendations
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(recommendations_cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
# Load based on specific filters
for filter_pattern in channel_filters:
# Extract domain and class from filter
if "." in filter_pattern:
parts = filter_pattern.split(".")
domain = parts[0] if parts[0] != "*" else None
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None
if domain and event_class:
# Specific cache (e.g., "inventory.alerts")
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif domain and not event_class:
# Domain wildcard (e.g., "inventory.*")
for ec in ["alerts", "notifications"]:
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif not domain and event_class:
# Class wildcard (e.g., "*.alerts")
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
for d in all_domains:
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif filter_pattern == "recommendations":
cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
except Exception as e:
logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
return []
def _determine_event_type(event_data: dict) -> str:
"""
Determine SSE event type from event data.
Args:
event_data: Event data dictionary
Returns:
SSE event type: 'alert', 'notification', or 'recommendation'
"""
# New event architecture uses 'event_class'
if 'event_class' in event_data:
return event_data['event_class'] # 'alert', 'notification', or 'recommendation'
# Legacy format uses 'item_type'
if 'item_type' in event_data:
if event_data['item_type'] == 'recommendation':
return 'recommendation'
else:
return 'alert'
# Default to 'alert' for backward compatibility
return 'alert'
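# Illustrative payloads (hypothetical field values, matching the two formats
# handled above):
#     {"event_class": "notification", ...}  -> SSE event type "notification"
#     {"item_type": "recommendation", ...}  -> "recommendation" (legacy format)
#     {"item_type": "low_stock", ...}       -> "alert" (legacy default)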
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================
@app.get("/api/events")
async def events_stream(
    request: Request,
    tenant_id: str,
    channels: str = None  # Comma-separated channel filters (e.g., "inventory.alerts,production.notifications")
):
    """
    Server-Sent Events stream for real-time notifications with multi-channel support.

    Authentication is handled by the auth middleware via a query param token.
    User context is available in request.state.user (injected by the middleware).

    Query Parameters:
        tenant_id: Tenant identifier (required)
        channels: Comma-separated channel filters (optional)
            Examples:
            - "inventory.alerts,production.notifications" - Specific channels
            - "*.alerts" - All alert channels
            - "inventory.*" - All inventory events
            - None - All channels (default, backward compatible)

    New channel pattern: tenant:{tenant_id}:{domain}.{class}
    Examples:
        - tenant:abc:inventory.alerts
        - tenant:abc:production.notifications
        - tenant:abc:recommendations
    Legacy channel (backward compat): alerts:{tenant_id}
    """
    global redis_client
    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")
    # Extract user context from the request state (set by the auth middleware)
    user_context = request.state.user
    user_id = user_context.get('user_id')
    email = user_context.get('email')
    # Validate the tenant_id parameter
    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
    # Parse channel filters
    channel_filters = []
    if channels:
        channel_filters = [c.strip() for c in channels.split(',') if c.strip()]
    logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")
    async def event_generator():
        """Generate server-sent events from Redis pub/sub with multi-channel support"""
        pubsub = None
        try:
            # Create a pubsub connection with resource monitoring
            pubsub = redis_client.pubsub()
            logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")
            # Monitor the connection count
            try:
                connection_info = await redis_client.info('clients')
                connected_clients = connection_info.get('connected_clients', 'unknown')
                logger.debug(f"Redis connected clients: {connected_clients}")
            except Exception:
                # Don't fail if we can't get connection info
                pass
            # Determine which channels to subscribe to
            subscription_channels = _get_subscription_channels(tenant_id, channel_filters)
            # Subscribe to all determined channels
            if subscription_channels:
                await pubsub.subscribe(*subscription_channels)
                logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
            else:
                # Fall back to the legacy channel if no channels were determined
                legacy_channel = f"alerts:{tenant_id}"
                await pubsub.subscribe(legacy_channel)
                logger.info(f"Subscribed to legacy channel: {legacy_channel}")
            # Send the initial connection event
            yield "event: connection\n"
            yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"
            # Fetch and send the initial state from the cache (domain-specific or legacy)
            initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
            if initial_events:
                logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
                yield "event: initial_state\n"
                yield f"data: {json.dumps(initial_events)}\n\n"
            else:
                # Send an empty initial state for compatibility
                yield "event: initial_state\n"
                yield f"data: {json.dumps([])}\n\n"
            heartbeat_counter = 0
            while True:
                # Check whether the client has disconnected
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break
                try:
                    # Get a message from Redis with a timeout
                    message = await asyncio.wait_for(
                        pubsub.get_message(ignore_subscribe_messages=True),
                        timeout=10.0
                    )
                    if message and message['type'] == 'message':
                        # Forward the event from Redis
                        event_data = json.loads(message['data'])
                        # Determine the event type for SSE
                        event_type = _determine_event_type(event_data)
                        # Add channel metadata for frontend routing
                        event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']
                        yield f"event: {event_type}\n"
                        yield f"data: {json.dumps(event_data)}\n\n"
                        logger.debug(f"SSE event sent to tenant {tenant_id}: {event_type} - {event_data.get('title')}")
                except asyncio.TimeoutError:
                    # Send a heartbeat every 10 timeouts (roughly 100 seconds)
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield "event: heartbeat\n"
                        yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0
        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
        finally:
            try:
                if pubsub:
                    try:
                        # Unsubscribe from all channels
                        await pubsub.unsubscribe()
                        logger.debug(f"Unsubscribed from Redis channels for tenant: {tenant_id}")
                    except Exception as unsubscribe_error:
                        logger.error(f"Failed to unsubscribe Redis pubsub for tenant {tenant_id}: {unsubscribe_error}")
                    try:
                        # Close the pubsub connection
                        await pubsub.close()
                        logger.debug(f"Closed Redis pubsub connection for tenant: {tenant_id}")
                    except Exception as close_error:
                        logger.error(f"Failed to close Redis pubsub for tenant {tenant_id}: {close_error}")
                logger.info(f"SSE connection closed for tenant: {tenant_id}")
            except Exception as finally_error:
                logger.error(f"Error in SSE cleanup for tenant {tenant_id}: {finally_error}")
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )
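# A minimal consumption sketch (hypothetical client, shown as a comment only;
# the host, port, tenant, and channel values are assumptions, and the auth
# middleware additionally expects a token query parameter defined there):
#
#     import asyncio
#     import httpx
#
#     async def consume():
#         async with httpx.AsyncClient(timeout=None) as client:
#             async with client.stream(
#                 "GET",
#                 "http://localhost:8000/api/events",
#                 params={"tenant_id": "abc", "channels": "inventory.alerts,*.notifications"},
#             ) as response:
#                 async for line in response.aiter_lines():
#                     print(line)  # raw "event: ..." / "data: {...}" SSE frames
#
#     asyncio.run(consume())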
# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""
Simple WebSocket proxy with token verification only.
Validates the token and forwards the connection to the training service.
"""
# Get token from query params
token = websocket.query_params.get("token")
if not token:
logger.warning("WebSocket proxy rejected - missing token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return
    # Verify the token
    from shared.auth.jwt_handler import JWTHandler
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
    try:
        payload = jwt_handler.verify_token(token)
        if not payload or not payload.get('user_id'):
            logger.warning("WebSocket proxy rejected - invalid token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            await websocket.accept()
            await websocket.close(code=1008, reason="Invalid token")
            return
        logger.info("WebSocket proxy - token verified",
                    user_id=payload.get('user_id'),
                    tenant_id=tenant_id,
                    job_id=job_id)
    except Exception as e:
        logger.warning("WebSocket proxy - token verification failed",
                       job_id=job_id,
                       error=str(e))
        await websocket.accept()
        await websocket.close(code=1008, reason="Token verification failed")
        return
    # Accept the connection
    await websocket.accept()
    # Build the WebSocket URL to the training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
    logger.info("Gateway proxying WebSocket to training service",
                job_id=job_id,
                training_ws_url=training_ws_url.replace(token, '***'))
    training_ws = None
    try:
        # Connect to the training service WebSocket
        import websockets
        from websockets.protocol import State
        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=120,  # Send a ping every 2 minutes (tolerates long training operations)
            ping_timeout=60,    # Wait up to 1 minute for a pong (graceful timeout)
            close_timeout=60,   # Increased close timeout for graceful shutdown
            open_timeout=30
        )
        logger.info("Gateway connected to training service WebSocket", job_id=job_id)

        async def forward_frontend_to_training():
            """Forward messages from the frontend to the training service"""
            try:
                while training_ws and training_ws.state == State.OPEN:
                    data = await websocket.receive()
                    if data.get("type") == "websocket.receive":
                        if "text" in data:
                            await training_ws.send(data["text"])
                        elif "bytes" in data:
                            await training_ws.send(data["bytes"])
                    elif data.get("type") == "websocket.disconnect":
                        break
            except Exception as e:
                logger.debug("Frontend to training forward ended", error=str(e))

        async def forward_training_to_frontend():
            """Forward messages from the training service to the frontend"""
            message_count = 0
            try:
                while training_ws and training_ws.state == State.OPEN:
                    message = await training_ws.recv()
                    await websocket.send_text(message)
                    message_count += 1
                    # Log every 10th message to track connectivity
                    if message_count % 10 == 0:
                        logger.debug("WebSocket proxy active",
                                     job_id=job_id,
                                     messages_forwarded=message_count)
            except Exception as e:
                logger.info("Training to frontend forward ended",
                            job_id=job_id,
                            messages_forwarded=message_count,
                            error=str(e))

        # Run both forwarding tasks concurrently
        await asyncio.gather(
            forward_frontend_to_training(),
            forward_training_to_frontend(),
            return_exceptions=True
        )
    except Exception as e:
        logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        # Cleanup
        if training_ws and training_ws.state == State.OPEN:
            try:
                await training_ws.close()
            except Exception:
                pass
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception:
            pass
        logger.info("WebSocket proxy connection closed", job_id=job_id)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)