# bakery-ia/gateway/app/main.py
"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
import resource
import os
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from typing import Dict, Any
from app.core.config import settings
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.rate_limiting import APIRateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
logger = structlog.get_logger()
# Check file descriptor limits and warn if too low
try:
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft_limit < 1024:
logger.warning(f"Low file descriptor limit detected: {soft_limit}. Gateway may experience 'too many open files' errors.")
logger.warning(f"Recommended: Increase limit with 'ulimit -n 4096' or higher for production.")
if soft_limit < 256:
logger.error(f"Critical: File descriptor limit ({soft_limit}) is too low for gateway operation!")
else:
logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
logger.debug(f"Could not check file descriptor limits: {e}")
# Check and log current working directory and permissions
try:
cwd = os.getcwd()
logger.info(f"Current working directory: {cwd}")
# Check if we can write to common log locations
test_locations = ["/var/log", "./logs", "."]
for location in test_locations:
try:
test_file = os.path.join(location, ".gateway_permission_test")
with open(test_file, 'w') as f:
f.write("test")
os.remove(test_file)
logger.info(f"Write permission confirmed for: {location}")
except Exception as e:
logger.warning(f"Cannot write to {location}: {e}")
except Exception as e:
logger.debug(f"Could not check directory permissions: {e}")
# Create FastAPI app
app = FastAPI(
title="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
redirect_slashes=False # Disable automatic trailing slash redirects
)
# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")
# Redis client for SSE streaming
redis_client = None
# CORS middleware - Add first
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Custom middleware - Add in REVERSE order (last added = first executed)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)  # Executes 8th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)  # Executes 7th - Simple rate limit
# Note: APIRateLimitMiddleware (executes 6th) is added on startup once the Redis client is available
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 5th
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)  # Executes 4th - Enforce read-only mode
app.add_middleware(AuthMiddleware)  # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware)  # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware)  # Executes 1st (outermost) - Generates request ID for tracing
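# Starlette builds the middleware stack so that the most recently added middleware wraps all the
# others; that is why RequestIDMiddleware (added last) sees each request first and
# LoggingMiddleware (added first) sees it last.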
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"]) # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Initialize shared Redis connection
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")
# Add API rate limiting middleware with Redis client
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
logger.info("API rate limiting middleware enabled with subscription-based quotas")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
logger.warning("API rate limiting middleware will fail open (allow all requests)")
    metrics_collector.register_counter(
        "gateway_auth_requests_total",
        "Total authentication requests"
    )
    metrics_collector.register_counter(
        "gateway_auth_responses_total",
        "Total authentication responses"
    )
    metrics_collector.register_counter(
        "gateway_auth_errors_total",
        "Total authentication errors"
    )
    metrics_collector.register_histogram(
        "gateway_request_duration_seconds",
        "Request duration in seconds"
    )
    logger.info("Metrics registered successfully")
metrics_collector.start_metrics_server(8080)
logger.info("API Gateway started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
logger.info("Shutting down API Gateway")
# Close shared Redis connection
await close_redis()
# Clean up service discovery
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}
@app.get("/metrics")
async def metrics():
"""Metrics endpoint for monitoring"""
return {"metrics": "enabled"}
# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================
def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
"""
Determine which Redis channels to subscribe to based on filters.
Args:
tenant_id: Tenant identifier
channel_filters: List of channel patterns (e.g., ["inventory.alerts", "*.notifications"])
Returns:
List of full channel names to subscribe to
Examples:
>>> _get_subscription_channels("abc", ["inventory.alerts"])
["tenant:abc:inventory.alerts"]
>>> _get_subscription_channels("abc", ["*.alerts"])
["tenant:abc:inventory.alerts", "tenant:abc:production.alerts", ...]
>>> _get_subscription_channels("abc", [])
["tenant:abc:inventory.alerts", "tenant:abc:inventory.notifications", ...]
"""
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
channels = []
if not channel_filters:
# Subscribe to ALL channels (backward compatible)
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
# Also subscribe to recommendations (tenant-wide)
channels.append(f"tenant:{tenant_id}:recommendations")
# Also subscribe to legacy channel for backward compatibility
channels.append(f"alerts:{tenant_id}")
return channels
# Parse filters and expand wildcards
for filter_pattern in channel_filters:
if filter_pattern == "*.*":
# All channels
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")
elif filter_pattern.endswith(".*"):
# Domain wildcard (e.g., "inventory.*")
domain = filter_pattern.split(".")[0]
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern.startswith("*."):
# Class wildcard (e.g., "*.alerts")
event_class = filter_pattern.split(".")[1]
if event_class == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
for domain in all_domains:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern == "recommendations":
# Recommendations channel
channels.append(f"tenant:{tenant_id}:recommendations")
else:
# Specific channel (e.g., "inventory.alerts")
channels.append(f"tenant:{tenant_id}:{filter_pattern}")
return list(set(channels)) # Remove duplicates
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
"""
Load initial state from Redis cache based on channel filters.
Args:
redis_client: Redis client
tenant_id: Tenant identifier
channel_filters: List of channel patterns
Returns:
List of initial events
"""
initial_events = []
try:
if not channel_filters:
# Load from legacy cache if no filters (backward compat)
legacy_cache_key = f"active_alerts:{tenant_id}"
cached_data = await redis_client.get(legacy_cache_key)
if cached_data:
return json.loads(cached_data)
# Also try loading from new domain-specific caches
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
for domain in all_domains:
for event_class in all_classes:
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
events = json.loads(cached_data)
initial_events.extend(events)
# Load recommendations
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(recommendations_cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
# Load based on specific filters
for filter_pattern in channel_filters:
# Extract domain and class from filter
if "." in filter_pattern:
parts = filter_pattern.split(".")
domain = parts[0] if parts[0] != "*" else None
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None
if domain and event_class:
# Specific cache (e.g., "inventory.alerts")
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif domain and not event_class:
# Domain wildcard (e.g., "inventory.*")
for ec in ["alerts", "notifications"]:
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif not domain and event_class:
# Class wildcard (e.g., "*.alerts")
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
for d in all_domains:
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif filter_pattern == "recommendations":
cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
except Exception as e:
logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
return []
def _determine_event_type(event_data: dict) -> str:
"""
Determine SSE event type from event data.
Args:
event_data: Event data dictionary
Returns:
SSE event type: 'alert', 'notification', or 'recommendation'
"""
# New event architecture uses 'event_class'
if 'event_class' in event_data:
return event_data['event_class'] # 'alert', 'notification', or 'recommendation'
# Legacy format uses 'item_type'
if 'item_type' in event_data:
if event_data['item_type'] == 'recommendation':
return 'recommendation'
else:
return 'alert'
# Default to 'alert' for backward compatibility
return 'alert'
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================
@app.get("/api/events")
async def events_stream(
request: Request,
tenant_id: str,
channels: str = None # Comma-separated channel filters (e.g., "inventory.alerts,production.notifications")
):
"""
Server-Sent Events stream for real-time notifications with multi-channel support.
Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).
Query Parameters:
tenant_id: Tenant identifier (required)
channels: Comma-separated channel filters (optional)
Examples:
- "inventory.alerts,production.notifications" - Specific channels
- "*.alerts" - All alert channels
- "inventory.*" - All inventory events
- None - All channels (default, backward compatible)
New channel pattern: tenant:{tenant_id}:{domain}.{class}
Examples:
- tenant:abc:inventory.alerts
- tenant:abc:production.notifications
- tenant:abc:recommendations
Legacy channel (backward compat): alerts:{tenant_id}
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract user context from request state (set by auth middleware)
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
# Parse channel filters
channel_filters = []
if channels:
channel_filters = [c.strip() for c in channels.split(',') if c.strip()]
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")
async def event_generator():
"""Generate server-sent events from Redis pub/sub with multi-channel support"""
pubsub = None
try:
# Create pubsub connection with resource monitoring
pubsub = redis_client.pubsub()
logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")
# Monitor connection count
try:
connection_info = await redis_client.info('clients')
connected_clients = connection_info.get('connected_clients', 'unknown')
logger.debug(f"Redis connected clients: {connected_clients}")
except Exception:
# Don't fail if we can't get connection info
pass
# Determine which channels to subscribe to
subscription_channels = _get_subscription_channels(tenant_id, channel_filters)
# Subscribe to all determined channels
if subscription_channels:
await pubsub.subscribe(*subscription_channels)
logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
else:
# Fallback to legacy channel if no channels specified
legacy_channel = f"alerts:{tenant_id}"
await pubsub.subscribe(legacy_channel)
logger.info(f"Subscribed to legacy channel: {legacy_channel}")
# Send initial connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"
# Fetch and send initial state from cache (domain-specific or legacy)
initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
if initial_events:
logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"
else:
# Send empty initial state for compatibility
yield f"event: initial_state\n"
yield f"data: {json.dumps([])}\n\n"
heartbeat_counter = 0
while True:
# Check if client has disconnected
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break
                # Poll Redis for the next message, waiting up to 10 seconds
                # (get_message returns None if nothing arrives within the timeout)
                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=10.0)
                if message and message['type'] == 'message':
                    # Forward the event from Redis
                    event_data = json.loads(message['data'])
                    # Determine event type for SSE
                    event_type = _determine_event_type(event_data)
                    # Add channel metadata for frontend routing
                    event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']
                    yield f"event: {event_type}\n"
                    yield f"data: {json.dumps(event_data)}\n\n"
                    logger.debug(f"SSE event sent to tenant {tenant_id}: {event_type} - {event_data.get('title')}")
                else:
                    # Nothing received in this poll window - send a heartbeat every 10 empty polls (~100 seconds)
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield f"event: heartbeat\n"
                        yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0
except asyncio.CancelledError:
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
finally:
try:
if pubsub:
try:
# Unsubscribe from all channels
await pubsub.unsubscribe()
logger.debug(f"Unsubscribed from Redis channels for tenant: {tenant_id}")
except Exception as unsubscribe_error:
logger.error(f"Failed to unsubscribe Redis pubsub for tenant {tenant_id}: {unsubscribe_error}")
try:
# Close pubsub connection
await pubsub.close()
logger.debug(f"Closed Redis pubsub connection for tenant: {tenant_id}")
except Exception as close_error:
logger.error(f"Failed to close Redis pubsub for tenant {tenant_id}: {close_error}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")
except Exception as finally_error:
logger.error(f"Error in SSE cleanup for tenant {tenant_id}: {finally_error}")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control",
}
)

# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """
    Simple WebSocket proxy with token verification only.
    Validates the token and forwards the connection to the training service.
    """
    # Get token from query params
    token = websocket.query_params.get("token")
if not token:
logger.warning("WebSocket proxy rejected - missing token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return
# Verify token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload or not payload.get('user_id'):
logger.warning("WebSocket proxy rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Invalid token")
return
logger.info("WebSocket proxy - token verified",
user_id=payload.get('user_id'),
tenant_id=tenant_id,
job_id=job_id)
except Exception as e:
logger.warning("WebSocket proxy - token verification failed",
job_id=job_id,
error=str(e))
await websocket.accept()
await websocket.close(code=1008, reason="Token verification failed")
return
# Accept the connection
await websocket.accept()
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
logger.info("Gateway proxying WebSocket to training service",
job_id=job_id,
training_ws_url=training_ws_url.replace(token, '***'))
training_ws = None
try:
# Connect to training service WebSocket
import websockets
from websockets.protocol import State
training_ws = await websockets.connect(
training_ws_url,
ping_interval=120, # Send ping every 2 minutes (tolerates long training operations)
ping_timeout=60, # Wait up to 1 minute for pong (graceful timeout)
close_timeout=60, # Increase close timeout for graceful shutdown
open_timeout=30
)
logger.info("Gateway connected to training service WebSocket", job_id=job_id)
async def forward_frontend_to_training():
"""Forward messages from frontend to training service"""
try:
while training_ws and training_ws.state == State.OPEN:
data = await websocket.receive()
if data.get("type") == "websocket.receive":
if "text" in data:
await training_ws.send(data["text"])
elif "bytes" in data:
await training_ws.send(data["bytes"])
elif data.get("type") == "websocket.disconnect":
break
except Exception as e:
logger.debug("Frontend to training forward ended", error=str(e))
async def forward_training_to_frontend():
"""Forward messages from training service to frontend"""
message_count = 0
try:
while training_ws and training_ws.state == State.OPEN:
message = await training_ws.recv()
await websocket.send_text(message)
message_count += 1
# Log every 10th message to track connectivity
if message_count % 10 == 0:
logger.debug("WebSocket proxy active",
job_id=job_id,
messages_forwarded=message_count)
except Exception as e:
logger.info("Training to frontend forward ended",
job_id=job_id,
messages_forwarded=message_count,
error=str(e))
# Run both forwarding tasks concurrently
await asyncio.gather(
forward_frontend_to_training(),
forward_training_to_frontend(),
return_exceptions=True
)
except Exception as e:
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
    finally:
        # Cleanup: close the upstream training connection, then the client connection
        if training_ws and training_ws.state == State.OPEN:
            try:
                await training_ws.close()
            except Exception:
                pass
        try:
            if websocket.client_state.name != 'DISCONNECTED':
                await websocket.close(code=1000, reason="Proxy closed")
        except Exception:
            pass
        logger.info("WebSocket proxy connection closed", job_id=job_id)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)