"""
API Gateway - Central entry point for all microservices.
Handles routing, authentication, rate limiting, and cross-cutting concerns.
"""

import asyncio
import json
import time
from typing import Dict, Any

import httpx
import redis.asyncio as aioredis
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.routes import auth, tenant, notification, nominatim, user, subscription, demo
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector

# Set up logging
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()

# Create FastAPI app
app = FastAPI(
    title="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Service discovery
service_discovery = ServiceDiscovery()

# Redis client for SSE streaming (initialized on startup)
redis_client = None

# CORS middleware - added first, so it sits innermost in the middleware stack
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS_LIST,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Custom middleware - added in REVERSE order (last added = first executed)
# Execution order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)                                                        # Executes 5th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)                                # Executes 4th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)   # Executes 3rd
app.add_middleware(AuthMiddleware)                                                           # Executes 2nd - checks for demo context
app.add_middleware(DemoMiddleware)                                                           # Executes 1st (outermost) - sets demo user context FIRST
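
# Conceptually, the custom stack above is equivalent to this nesting
# (a sketch of Starlette's onion model, not code that runs here):
#
#   DemoMiddleware(
#       AuthMiddleware(
#           SubscriptionMiddleware(
#               RateLimitMiddleware(
#                   LoggingMiddleware(app)))))
#
# A request passes Demo -> Auth -> Subscription -> RateLimit -> Logging on the
# way in, and the response unwinds in the opposite order.
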
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(user.router, prefix="/api/v1/users", tags=["users"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


@app.on_event("startup")
|
|
|
|
|
async def startup_event():
|
|
|
|
|
"""Application startup"""
|
2025-09-04 23:19:53 +02:00
|
|
|
global redis_client
|
|
|
|
|
|
2025-07-17 13:09:24 +02:00
|
|
|
logger.info("Starting API Gateway")
|
|
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
# Connect to Redis for SSE streaming
|
|
|
|
|
try:
|
|
|
|
|
redis_client = aioredis.from_url(settings.REDIS_URL)
|
|
|
|
|
logger.info("Connected to Redis for SSE streaming")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to connect to Redis: {e}")
|
2025-08-02 21:56:25 +02:00
|
|
|
|
2025-07-26 21:10:54 +02:00
|
|
|
metrics_collector.register_counter(
|
|
|
|
|
"gateway_auth_requests_total",
|
2025-08-02 21:56:25 +02:00
|
|
|
"Total authentication requests"
|
2025-07-26 21:10:54 +02:00
|
|
|
)
|
|
|
|
|
metrics_collector.register_counter(
|
|
|
|
|
"gateway_auth_responses_total",
|
2025-08-02 21:56:25 +02:00
|
|
|
"Total authentication responses"
|
|
|
|
|
)
|
|
|
|
|
metrics_collector.register_counter(
|
|
|
|
|
"gateway_auth_errors_total",
|
|
|
|
|
"Total authentication errors"
|
2025-07-26 21:10:54 +02:00
|
|
|
)
|
2025-08-02 21:56:25 +02:00
|
|
|
|
2025-07-26 21:10:54 +02:00
|
|
|
metrics_collector.register_histogram(
|
|
|
|
|
"gateway_request_duration_seconds",
|
2025-08-02 21:56:25 +02:00
|
|
|
"Request duration in seconds"
|
2025-07-26 21:10:54 +02:00
|
|
|
)
|
|
|
|
|
|
2025-08-02 21:56:25 +02:00
|
|
|
logger.info("Metrics registered successfully")
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-26 21:10:54 +02:00
|
|
|
metrics_collector.start_metrics_server(8080)
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
logger.info("API Gateway started successfully")
|
|
|
|
|
|
|
|
|
|
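# Hedged usage sketch (not from this file): recording samples against the
# metrics registered above. The method names increment_counter/observe_histogram
# are hypothetical - check shared.monitoring.metrics.MetricsCollector for the
# actual API.
#
#   metrics_collector.increment_counter("gateway_auth_requests_total")
#   metrics_collector.observe_histogram("gateway_request_duration_seconds", 0.042)
#
# The metrics server started above exposes these samples on port 8080.
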
@app.on_event("shutdown")
|
|
|
|
|
async def shutdown_event():
|
|
|
|
|
"""Application shutdown"""
|
2025-09-04 23:19:53 +02:00
|
|
|
global redis_client
|
|
|
|
|
|
2025-07-17 13:09:24 +02:00
|
|
|
logger.info("Shutting down API Gateway")
|
|
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
# Close Redis connection
|
|
|
|
|
if redis_client:
|
|
|
|
|
await redis_client.close()
|
|
|
|
|
|
2025-07-17 13:09:24 +02:00
|
|
|
# Clean up service discovery
|
2025-07-18 12:57:13 +02:00
|
|
|
# await service_discovery.cleanup()
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
logger.info("API Gateway shutdown complete")
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
|
|
|
|
|
async def health_check():
|
|
|
|
|
"""Health check endpoint"""
|
|
|
|
|
return {
|
|
|
|
|
"status": "healthy",
|
2025-07-17 19:54:04 +02:00
|
|
|
"service": "api-gateway",
|
2025-07-17 13:09:24 +02:00
|
|
|
"version": "1.0.0",
|
|
|
|
|
"timestamp": time.time()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@app.get("/metrics")
|
2025-07-17 19:54:04 +02:00
|
|
|
async def metrics():
|
|
|
|
|
"""Metrics endpoint for monitoring"""
|
|
|
|
|
return {"metrics": "enabled"}
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/events")
|
2025-10-02 13:20:30 +02:00
|
|
|
async def events_stream(request: Request, tenant_id: str):
|
|
|
|
|
"""
|
|
|
|
|
Server-Sent Events stream for real-time notifications.
|
|
|
|
|
|
|
|
|
|
Authentication is handled by auth middleware via query param token.
|
|
|
|
|
User context is available in request.state.user (injected by middleware).
|
|
|
|
|
Tenant ID is provided by the frontend as a query parameter.
|
|
|
|
|
"""
|
2025-09-04 23:19:53 +02:00
|
|
|
global redis_client
|
2025-10-02 13:20:30 +02:00
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
if not redis_client:
|
|
|
|
|
raise HTTPException(status_code=503, detail="SSE service unavailable")
|
2025-10-02 13:20:30 +02:00
|
|
|
|
|
|
|
|
# Extract user context from request state (set by auth middleware)
|
|
|
|
|
user_context = request.state.user
|
|
|
|
|
user_id = user_context.get('user_id')
|
|
|
|
|
email = user_context.get('email')
|
|
|
|
|
|
|
|
|
|
# Validate tenant_id parameter
|
|
|
|
|
if not tenant_id:
|
|
|
|
|
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
|
|
|
|
|
|
|
|
|
|
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
|
2025-09-04 23:19:53 +02:00
|
|
|
|
|
|
|
|
logger.info(f"SSE connection established for tenant: {tenant_id}")
|
|
|
|
|
|
|
|
|
|
async def event_generator():
|
|
|
|
|
"""Generate server-sent events from Redis pub/sub"""
|
|
|
|
|
pubsub = None
|
|
|
|
|
try:
|
|
|
|
|
# Subscribe to tenant-specific alert channel
|
|
|
|
|
pubsub = redis_client.pubsub()
|
|
|
|
|
channel_name = f"alerts:{tenant_id}"
|
|
|
|
|
await pubsub.subscribe(channel_name)
|
|
|
|
|
|
|
|
|
|
# Send initial connection event
|
|
|
|
|
yield f"event: connection\n"
|
2025-10-02 13:20:30 +02:00
|
|
|
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
|
2025-09-04 23:19:53 +02:00
|
|
|
|
|
|
|
|
heartbeat_counter = 0
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
# Check if client has disconnected
|
|
|
|
|
if await request.is_disconnected():
|
|
|
|
|
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Get message from Redis with timeout
|
|
|
|
|
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
|
|
|
|
|
|
|
|
|
|
if message and message['type'] == 'message':
|
|
|
|
|
# Forward the alert/notification from Redis
|
2025-10-02 13:20:30 +02:00
|
|
|
alert_data = json.loads(message['data'])
|
|
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
# Determine event type based on alert data
|
|
|
|
|
event_type = "notification"
|
|
|
|
|
if alert_data.get('item_type') == 'alert':
|
|
|
|
|
if alert_data.get('severity') in ['high', 'urgent']:
|
|
|
|
|
event_type = "inventory_alert"
|
|
|
|
|
else:
|
|
|
|
|
event_type = "notification"
|
|
|
|
|
elif alert_data.get('item_type') == 'recommendation':
|
|
|
|
|
event_type = "notification"
|
2025-10-02 13:20:30 +02:00
|
|
|
|
2025-09-04 23:19:53 +02:00
|
|
|
yield f"event: {event_type}\n"
|
2025-10-02 13:20:30 +02:00
|
|
|
yield f"data: {json.dumps(alert_data)}\n\n"
|
2025-09-04 23:19:53 +02:00
|
|
|
|
|
|
|
|
logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
|
|
|
|
|
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
# Send heartbeat every 10 timeouts (100 seconds)
|
|
|
|
|
heartbeat_counter += 1
|
|
|
|
|
if heartbeat_counter >= 10:
|
|
|
|
|
yield f"event: heartbeat\n"
|
2025-10-02 13:20:30 +02:00
|
|
|
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
|
2025-09-04 23:19:53 +02:00
|
|
|
heartbeat_counter = 0
|
|
|
|
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"SSE error for tenant {tenant_id}: {e}")
|
|
|
|
|
finally:
|
|
|
|
|
if pubsub:
|
|
|
|
|
await pubsub.unsubscribe()
|
|
|
|
|
await pubsub.close()
|
|
|
|
|
logger.info(f"SSE connection closed for tenant: {tenant_id}")
|
|
|
|
|
|
|
|
|
|
return StreamingResponse(
|
|
|
|
|
event_generator(),
|
|
|
|
|
media_type="text/event-stream",
|
|
|
|
|
headers={
|
|
|
|
|
"Cache-Control": "no-cache",
|
|
|
|
|
"Connection": "keep-alive",
|
|
|
|
|
"Access-Control-Allow-Origin": "*",
|
|
|
|
|
"Access-Control-Allow-Headers": "Cache-Control",
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
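
# Hedged client-side sketch (not part of the gateway): consuming the SSE
# endpoint above with httpx. The base URL and token handling are illustrative
# assumptions; the auth token and tenant_id query parameters mirror the
# docstring above.
#
#   async def consume_events(base_url: str, token: str, tenant_id: str):
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "GET",
#               f"{base_url}/api/events",
#               params={"token": token, "tenant_id": tenant_id},
#           ) as response:
#               async for line in response.aiter_lines():
#                   if line.startswith("data: "):
#                       print(json.loads(line[len("data: "):]))
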

# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
|
|
|
|
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
|
2025-09-29 07:54:25 +02:00
|
|
|
"""WebSocket proxy that forwards connections directly to training service with enhanced token validation"""
|
2025-08-14 13:26:59 +02:00
|
|
|
await websocket.accept()
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-08-14 13:26:59 +02:00
|
|
|
# Get token from query params
|
|
|
|
|
token = websocket.query_params.get("token")
|
|
|
|
|
if not token:
|
2025-08-15 17:53:59 +02:00
|
|
|
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
|
2025-08-14 13:26:59 +02:00
|
|
|
await websocket.close(code=1008, reason="Authentication token required")
|
|
|
|
|
return
|
2025-09-29 07:54:25 +02:00
|
|
|
|
|
|
|
|
# Validate token using auth middleware
|
|
|
|
|
from app.middleware.auth import jwt_handler
|
|
|
|
|
try:
|
|
|
|
|
payload = jwt_handler.verify_token(token)
|
|
|
|
|
if not payload:
|
|
|
|
|
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
|
|
|
|
|
await websocket.close(code=1008, reason="Invalid authentication token")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Check token expiration
|
|
|
|
|
import time
|
|
|
|
|
if payload.get('exp', 0) < time.time():
|
|
|
|
|
logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
|
|
|
|
|
await websocket.close(code=1008, reason="Token expired")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
|
|
|
|
|
await websocket.close(code=1008, reason="Token validation failed")
|
|
|
|
|
return
|
|
|
|
|
|
    logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")

    # Build the WebSocket URL for the training service
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

    training_ws = None
    heartbeat_task = None

    try:
        # Connect to the training service WebSocket with coordinated timeouts
        import websockets

        # Timeouts coordinate with the frontend (30s heartbeat) and the training
        # service. Gateway-level ping is DISABLED to avoid dual-ping conflicts -
        # the frontend drives ping/pong through normal message forwarding.
        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=None,   # DISABLED: frontend handles ping/pong via message forwarding
            ping_timeout=None,    # DISABLED: no independent ping mechanism
            close_timeout=15,     # Reasonable close timeout
            max_size=2**20,       # 1MB max message size
            max_queue=32          # Max queued messages
        )

        logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping disabled (frontend handles ping/pong)")

        # Track connection state explicitly; FastAPI's WebSocket state does not
        # always propagate disconnects promptly across tasks
        connection_alive = True
        last_activity = asyncio.get_event_loop().time()

        async def check_connection_health():
            """Monitor connection health from activity timestamps only - no WebSocket interference"""
            nonlocal connection_alive, last_activity

            while connection_alive:
                try:
                    await asyncio.sleep(30)  # Check every 30 seconds (aligned with frontend heartbeat)
                    current_time = asyncio.get_event_loop().time()

                    # Frontend pings every 30s, so 90s of silence = 3 missed pings
                    if current_time - last_activity > 90:
                        logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
                        # Don't forcibly close - the forwarding loops handle actual
                        # connection failures; this task only monitors and logs
                    else:
                        logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")

                except Exception as e:
                    logger.error(f"Connection health monitoring error for job {job_id}: {e}")
                    break

        async def forward_to_training():
            """Forward messages from the frontend to the training service"""
            nonlocal connection_alive, last_activity

            try:
                while connection_alive and training_ws and training_ws.open:
                    try:
                        # 45s timeout leaves headroom over the frontend's 30s
                        # heartbeat for network latency and processing time
                        data = await asyncio.wait_for(websocket.receive(), timeout=45.0)
                        last_activity = asyncio.get_event_loop().time()

                        # Handle the raw ASGI message types
                        if data.get("type") == "websocket.receive":
                            if "text" in data:
                                message = data["text"]
                                # Forward text messages to the training service
                                await training_ws.send(message)
                                logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")
                            elif "bytes" in data:
                                # Forward binary messages if needed
                                await training_ws.send(data["bytes"])
                            # Ping/pong frames are handled automatically by Starlette/FastAPI
                        elif data.get("type") == "websocket.disconnect":
                            # Client went away; stop instead of letting the next
                            # receive() raise
                            connection_alive = False
                            break

                    except asyncio.TimeoutError:
                        # No message within 45 seconds; keep waiting
                        continue
                    except Exception as e:
                        logger.error(f"Error receiving from frontend for job {job_id}: {e}")
                        connection_alive = False
                        break

            except Exception as e:
                logger.error(f"Error in forward_to_training for job {job_id}: {e}")
                connection_alive = False

        async def forward_to_frontend():
            """Forward messages from the training service to the frontend"""
            nonlocal connection_alive, last_activity

            try:
                while connection_alive and training_ws and training_ws.open:
                    try:
                        # 75s timeout stays above the training service's 60s
                        # message cadence to avoid premature closure
                        message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
                        last_activity = asyncio.get_event_loop().time()

                        # Forward the message to the frontend
                        await websocket.send_text(message)
                        logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")

                    except asyncio.TimeoutError:
                        # No message within 75 seconds; the training service sends
                        # heartbeats, so prolonged silence may indicate a problem
                        continue
                    except Exception as e:
                        logger.error(f"Error receiving from training service for job {job_id}: {e}")
                        connection_alive = False
                        break

            except Exception as e:
                logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
                connection_alive = False

        # Start connection health monitoring
        heartbeat_task = asyncio.create_task(check_connection_health())

        # Run both forwarding tasks concurrently
        try:
            await asyncio.gather(
                forward_to_training(),
                forward_to_frontend(),
                return_exceptions=True
            )
        except Exception as e:
            logger.error(f"Error in WebSocket forwarding tasks for job {job_id}: {e}")
        finally:
            connection_alive = False

    except websockets.exceptions.ConnectionClosedError as e:
        logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
    except websockets.exceptions.WebSocketException as e:
        logger.error(f"WebSocket exception for job {job_id}: {e}")
    except Exception as e:
        logger.error(f"WebSocket proxy error for job {job_id}: {e}")
    finally:
        # Cancel the health monitor
        if heartbeat_task and not heartbeat_task.done():
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        # Close the upstream connection to the training service
        if training_ws and not training_ws.closed:
            try:
                await training_ws.close()
            except Exception as e:
                logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")

        # Close the frontend connection if it is still open
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy connection closed")
        except Exception as e:
            logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")

        logger.info(f"WebSocket proxy cleanup completed for job {job_id}")


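# Hedged client-side sketch (not part of the gateway): connecting to the
# WebSocket proxy above with the `websockets` package. The localhost URL is an
# illustrative assumption; the path and token mirror the route definition.
#
#   async def watch_training(tenant_id: str, job_id: str, token: str):
#       import websockets
#       url = (f"ws://localhost:8000/api/v1/ws/tenants/{tenant_id}"
#              f"/training/jobs/{job_id}/live?token={token}")
#       async with websockets.connect(url) as ws:
#           async for message in ws:
#               print(message)  # progress updates forwarded from the training service

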
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)