Improve monitoring 5

Urtzi Alfaro
2026-01-09 23:14:12 +01:00
parent 22dab143ba
commit c05538cafb
23 changed files with 4737 additions and 1932 deletions


@@ -8,13 +8,12 @@ import json
import structlog
import resource
import os
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
import httpx
import time
from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
import httpx
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from typing import Dict, Any
from shared.service_base import StandardFastAPIService
from app.core.config import settings
from app.middleware.request_id import RequestIDMiddleware
@@ -26,128 +25,84 @@ from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector, add_metrics_middleware
from shared.monitoring.system_metrics import SystemMetricsCollector
# OpenTelemetry imports
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.sdk.resources import Resource
# Configure OpenTelemetry tracing
def setup_tracing(service_name: str = "gateway"):
"""Initialize OpenTelemetry tracing with OTLP exporter for Jaeger"""
# Create resource with service name
resource = Resource.create({"service.name": service_name})
# Configure OTLP exporter (sends to OpenTelemetry Collector)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://signoz-otel-collector.bakery-ia.svc.cluster.local:4317"),
insecure=True # Use insecure connection for internal cluster communication
)
# Configure tracer provider
provider = TracerProvider(resource=resource)
processor = BatchSpanProcessor(otlp_exporter)
provider.add_span_processor(processor)
# Set global tracer provider
trace.set_tracer_provider(provider)
return provider
# Initialize tracing
tracer_provider = setup_tracing("gateway")
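For reference, once setup_tracing() installs the global provider, any module in the service can emit spans through the standard OpenTelemetry API. A minimal sketch (span name and attribute are illustrative, not from this diff):

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    def traced_lookup(tenant_id: str):
        # Spans created here are batched by the BatchSpanProcessor and
        # shipped to the OTLP endpoint configured above.
        with tracer.start_as_current_span("gateway.tenant_lookup") as span:
            span.set_attribute("tenant.id", tenant_id)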
# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
# Initialize logger
logger = structlog.get_logger()
# Check file descriptor limits and warn if too low
# Check file descriptor limits
try:
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft_limit < 1024:
logger.warning(f"Low file descriptor limit detected: {soft_limit}. Gateway may experience 'too many open files' errors.")
logger.warning(f"Recommended: Increase limit with 'ulimit -n 4096' or higher for production.")
if soft_limit < 256:
logger.error(f"Critical: File descriptor limit ({soft_limit}) is too low for gateway operation!")
logger.warning(f"Low file descriptor limit detected: {soft_limit}")
else:
logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
logger.debug(f"Could not check file descriptor limits: {e}")
# Check and log current working directory and permissions
try:
cwd = os.getcwd()
logger.info(f"Current working directory: {cwd}")
# Check if we can write to common log locations
test_locations = ["/var/log", "./logs", "."]
for location in test_locations:
try:
test_file = os.path.join(location, ".gateway_permission_test")
with open(test_file, 'w') as f:
f.write("test")
os.remove(test_file)
logger.info(f"Write permission confirmed for: {location}")
except Exception as e:
logger.warning(f"Cannot write to {location}: {e}")
except Exception as e:
logger.debug(f"Could not check directory permissions: {e}")
# Create FastAPI app
app = FastAPI(
title="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
redirect_slashes=False # Disable automatic trailing slash redirects
)
# Instrument FastAPI with OpenTelemetry
FastAPIInstrumentor.instrument_app(app)
# Instrument httpx for outgoing requests
HTTPXClientInstrumentor().instrument()
# Instrument Redis (will be active once redis client is initialized)
RedisInstrumentor().instrument()
# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")
# Add metrics middleware to track HTTP requests
add_metrics_middleware(app, metrics_collector)
# Redis client for SSE streaming
redis_client = None
# CORS middleware - Add first
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
class GatewayService(StandardFastAPIService):
"""Gateway Service with standardized monitoring setup"""
async def on_startup(self, app):
"""Custom startup logic for Gateway"""
global redis_client
# Initialize Redis
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")
# Add API rate limiting middleware with Redis client
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
logger.info("API rate limiting middleware enabled")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
# Register custom metrics for gateway-specific operations
if self.telemetry_providers and self.telemetry_providers.app_metrics:
logger.info("Gateway-specific metrics tracking enabled")
await super().on_startup(app)
async def on_shutdown(self, app):
"""Custom shutdown logic for Gateway"""
await super().on_shutdown(app)
# Close Redis
await close_redis()
logger.info("Redis connection closed")
# Create service instance
service = GatewayService(
service_name="gateway",
app_name="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
cors_origins=settings.CORS_ORIGINS_LIST,
enable_metrics=True,
enable_health_checks=True,
enable_tracing=True,
enable_cors=True
)
# Custom middleware - Add in REVERSE order (last added = first executed)
# Create FastAPI app
app = service.create_app()
# Add gateway-specific middleware (in REVERSE order of execution)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware) # Executes 8th (innermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 7th - Simple rate limit
# Note: APIRateLimitMiddleware will be added on startup with Redis client
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 5th
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 4th - Enforce read-only mode
app.add_middleware(AuthMiddleware) # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware) # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware) # Executes 1st (outermost) - Generates request ID for tracing
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(AuthMiddleware)
app.add_middleware(DemoMiddleware)
app.add_middleware(RequestIDMiddleware)
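The ordering comments above rely on Starlette's rule that each add_middleware() call wraps everything added before it, so the middleware added last runs first. A standalone sketch of that rule (the Tag middleware is hypothetical, purely for illustration):

    from fastapi import FastAPI
    from starlette.middleware.base import BaseHTTPMiddleware

    class Tag(BaseHTTPMiddleware):
        def __init__(self, app, name: str):
            super().__init__(app)
            self.name = name

        async def dispatch(self, request, call_next):
            print(f"enter {self.name}")  # last-added middleware prints first
            response = await call_next(request)
            print(f"exit {self.name}")
            return response

    demo = FastAPI()
    demo.add_middleware(Tag, name="inner")  # added first -> innermost, runs last
    demo.add_middleware(Tag, name="outer")  # added last  -> outermost, runs first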
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
@@ -156,114 +111,18 @@ app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"]
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"]) # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Initialize shared Redis connection
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")
# Add API rate limiting middleware with Redis client
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
logger.info("API rate limiting middleware enabled with subscription-based quotas")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
logger.warning("API rate limiting middleware will fail open (allow all requests)")
metrics_collector.register_counter(
"gateway_auth_requests_total",
"Total authentication requests"
)
metrics_collector.register_counter(
"gateway_auth_responses_total",
"Total authentication responses"
)
metrics_collector.register_counter(
"gateway_auth_errors_total",
"Total authentication errors"
)
metrics_collector.register_histogram(
"gateway_request_duration_seconds",
"Request duration in seconds"
)
logger.info("Metrics registered successfully")
# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz - no metrics server needed
# Initialize system metrics collection
system_metrics = SystemMetricsCollector("gateway")
logger.info("System metrics collection started")
logger.info("Metrics export configured via OpenTelemetry OTLP")
logger.info("API Gateway started successfully")
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
logger.info("Shutting down API Gateway")
# Close shared Redis connection
await close_redis()
# Clean up service discovery
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}
# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
# The /metrics endpoint is not needed as metrics are pushed automatically
# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================
def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
"""
Determine which Redis channels to subscribe to based on filters.
Args:
tenant_id: Tenant identifier
channel_filters: List of channel patterns (e.g., ["inventory.alerts", "*.notifications"])
Returns:
List of full channel names to subscribe to
Examples:
>>> _get_subscription_channels("abc", ["inventory.alerts"])
["tenant:abc:inventory.alerts"]
>>> _get_subscription_channels("abc", ["*.alerts"])
["tenant:abc:inventory.alerts", "tenant:abc:production.alerts", ...]
>>> _get_subscription_channels("abc", [])
["tenant:abc:inventory.alerts", "tenant:abc:inventory.notifications", ...]
"""
"""Determine which Redis channels to subscribe to based on filters"""
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
channels = []
if not channel_filters:
@@ -271,70 +130,49 @@ def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
# Also subscribe to recommendations (tenant-wide)
channels.append(f"tenant:{tenant_id}:recommendations")
# Also subscribe to legacy channel for backward compatibility
channels.append(f"alerts:{tenant_id}")
channels.append(f"alerts:{tenant_id}") # Legacy
return channels
# Parse filters and expand wildcards
for filter_pattern in channel_filters:
if filter_pattern == "*.*":
# All channels
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")
elif filter_pattern.endswith(".*"):
# Domain wildcard (e.g., "inventory.*")
domain = filter_pattern.split(".")[0]
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern.startswith("*."):
# Class wildcard (e.g., "*.alerts")
event_class = filter_pattern.split(".")[1]
if event_class == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
for domain in all_domains:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern == "recommendations":
# Recommendations channel
channels.append(f"tenant:{tenant_id}:recommendations")
else:
# Specific channel (e.g., "inventory.alerts")
channels.append(f"tenant:{tenant_id}:{filter_pattern}")
return list(set(channels)) # Remove duplicates
return list(set(channels))
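Putting the branches above together, the expansion behaves roughly like this (a doctest-style sketch; ordering is not guaranteed because duplicates are stripped through a set):

    channels = _get_subscription_channels("abc", ["inventory.*", "*.alerts"])
    # -> tenant:abc:inventory.alerts, tenant:abc:inventory.notifications,
    #    tenant:abc:production.alerts, tenant:abc:supply_chain.alerts,
    #    tenant:abc:demand.alerts, tenant:abc:operations.alerts
    # (deduplicated: inventory.alerts appears once despite matching both filters)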
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
"""
Load initial state from Redis cache based on channel filters.
Args:
redis_client: Redis client
tenant_id: Tenant identifier
channel_filters: List of channel patterns
Returns:
List of initial events
"""
"""Load initial state from Redis cache based on channel filters"""
initial_events = []
try:
if not channel_filters:
# Load from legacy cache if no filters (backward compat)
# Legacy cache
legacy_cache_key = f"active_alerts:{tenant_id}"
cached_data = await redis_client.get(legacy_cache_key)
if cached_data:
return json.loads(cached_data)
# Also try loading from new domain-specific caches
# New domain-specific caches
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
@@ -343,10 +181,9 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
events = json.loads(cached_data)
initial_events.extend(events)
initial_events.extend(json.loads(cached_data))
# Load recommendations
# Recommendations
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(recommendations_cache_key)
if cached_data:
@@ -356,36 +193,29 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis
# Load based on specific filters
for filter_pattern in channel_filters:
# Extract domain and class from filter
if "." in filter_pattern:
parts = filter_pattern.split(".")
domain = parts[0] if parts[0] != "*" else None
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None
if domain and event_class:
# Specific cache (e.g., "inventory.alerts")
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif domain and not event_class:
# Domain wildcard (e.g., "inventory.*")
for ec in ["alerts", "notifications"]:
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif not domain and event_class:
# Class wildcard (e.g., "*.alerts")
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
for d in all_domains:
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif filter_pattern == "recommendations":
cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(cache_key)
@@ -400,27 +230,14 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis
def _determine_event_type(event_data: dict) -> str:
"""
Determine SSE event type from event data.
Args:
event_data: Event data dictionary
Returns:
SSE event type: 'alert', 'notification', or 'recommendation'
"""
# New event architecture uses 'event_class'
"""Determine SSE event type from event data"""
if 'event_class' in event_data:
return event_data['event_class'] # 'alert', 'notification', or 'recommendation'
# Legacy format uses 'item_type'
return event_data['event_class']
if 'item_type' in event_data:
if event_data['item_type'] == 'recommendation':
return 'recommendation'
else:
return 'alert'
# Default to 'alert' for backward compatibility
return 'alert'
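The fallback chain above at a glance (doctest-style, values taken from the function itself):

    _determine_event_type({"event_class": "notification"})  # -> 'notification'
    _determine_event_type({"item_type": "recommendation"})  # -> 'recommendation' (legacy)
    _determine_event_type({"item_type": "stockout"})        # -> 'alert' (legacy, non-recommendation)
    _determine_event_type({"title": "Low stock"})           # -> 'alert' (default)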
@@ -432,42 +249,25 @@ def _determine_event_type(event_data: dict) -> str:
async def events_stream(
request: Request,
tenant_id: str,
channels: str = None # Comma-separated channel filters (e.g., "inventory.alerts,production.notifications")
channels: str = None
):
"""
Server-Sent Events stream for real-time notifications with multi-channel support.
Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).
Query Parameters:
tenant_id: Tenant identifier (required)
channels: Comma-separated channel filters (optional)
Examples:
- "inventory.alerts,production.notifications" - Specific channels
- "*.alerts" - All alert channels
- "inventory.*" - All inventory events
- None - All channels (default, backward compatible)
New channel pattern: tenant:{tenant_id}:{domain}.{class}
Examples:
- tenant:abc:inventory.alerts
- tenant:abc:production.notifications
- tenant:abc:recommendations
Legacy channel (backward compat): alerts:{tenant_id}
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract user context from request state (set by auth middleware)
# Extract user context from request state
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
@@ -479,79 +279,53 @@ async def events_stream(
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")
async def event_generator():
"""Generate server-sent events from Redis pub/sub with multi-channel support"""
"""Generate server-sent events from Redis pub/sub"""
pubsub = None
try:
# Create pubsub connection with resource monitoring
pubsub = redis_client.pubsub()
logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")
# Monitor connection count
try:
connection_info = await redis_client.info('clients')
connected_clients = connection_info.get('connected_clients', 'unknown')
logger.debug(f"Redis connected clients: {connected_clients}")
except Exception:
# Don't fail if we can't get connection info
pass
# Determine which channels to subscribe to
# Determine channels
subscription_channels = _get_subscription_channels(tenant_id, channel_filters)
# Subscribe to all determined channels
# Subscribe
if subscription_channels:
await pubsub.subscribe(*subscription_channels)
logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
else:
# Fallback to legacy channel if no channels specified
legacy_channel = f"alerts:{tenant_id}"
await pubsub.subscribe(legacy_channel)
logger.info(f"Subscribed to legacy channel: {legacy_channel}")
# Send initial connection event
# Connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"
# Fetch and send initial state from cache (domain-specific or legacy)
# Initial state
initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
if initial_events:
logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"
else:
# Send empty initial state for compatibility
yield f"event: initial_state\n"
yield f"data: {json.dumps([])}\n\n"
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"
heartbeat_counter = 0
while True:
# Check if client has disconnected
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break
try:
# Get message from Redis with timeout
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
if message and message['type'] == 'message':
# Forward the event from Redis
event_data = json.loads(message['data'])
# Determine event type for SSE
event_type = _determine_event_type(event_data)
# Add channel metadata for frontend routing
event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']
yield f"event: {event_type}\n"
yield f"data: {json.dumps(event_data)}\n\n"
logger.debug(f"SSE event sent to tenant {tenant_id}: {event_type} - {event_data.get('title')}")
except asyncio.TimeoutError:
# Send heartbeat every 10 timeouts (100 seconds)
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
@@ -563,24 +337,13 @@ async def events_stream(
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
finally:
try:
if pubsub:
try:
# Unsubscribe from all channels
await pubsub.unsubscribe()
logger.debug(f"Unsubscribed from Redis channels for tenant: {tenant_id}")
except Exception as unsubscribe_error:
logger.error(f"Failed to unsubscribe Redis pubsub for tenant {tenant_id}: {unsubscribe_error}")
try:
# Close pubsub connection
await pubsub.close()
logger.debug(f"Closed Redis pubsub connection for tenant: {tenant_id}")
except Exception as close_error:
logger.error(f"Failed to close Redis pubsub for tenant {tenant_id}: {close_error}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")
except Exception as finally_error:
logger.error(f"Error in SSE cleanup for tenant {tenant_id}: {finally_error}")
if pubsub:
try:
await pubsub.unsubscribe()
await pubsub.close()
except Exception as e:
logger.error(f"Error closing pubsub: {e}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")
return StreamingResponse(
event_generator(),
@@ -593,55 +356,35 @@ async def events_stream(
}
)
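For reference, a minimal client for this stream. The host and route below are placeholders (the route decorator sits outside this hunk); the token goes in the query string, which is where the auth middleware reads it per the docstring above:

    import asyncio
    import httpx

    async def consume_events(tenant_id: str, token: str):
        url = "https://gateway.example.com/api/v1/events/stream"  # hypothetical host/path
        params = {
            "tenant_id": tenant_id,
            "channels": "inventory.alerts,*.notifications",
            "token": token,
        }
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("GET", url, params=params) as response:
                async for line in response.aiter_lines():
                    if line:  # skip the blank separators between SSE frames
                        print(line)

    asyncio.run(consume_events("abc", "<jwt>"))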
# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""
Simple WebSocket proxy with token verification only.
Validates the token and forwards the connection to the training service.
"""
# Get token from query params
"""WebSocket proxy with token verification for training service"""
token = websocket.query_params.get("token")
if not token:
logger.warning("WebSocket proxy rejected - missing token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return
# Verify token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload or not payload.get('user_id'):
logger.warning("WebSocket proxy rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Invalid token")
return
logger.info("WebSocket proxy - token verified",
user_id=payload.get('user_id'),
tenant_id=tenant_id,
job_id=job_id)
except Exception as e:
logger.warning("WebSocket proxy - token verification failed",
job_id=job_id,
error=str(e))
await websocket.accept()
await websocket.close(code=1008, reason="Token verification failed")
return
# Accept the connection
await websocket.accept()
# Build WebSocket URL to training service
@@ -649,33 +392,24 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
logger.info("Gateway proxying WebSocket to training service",
job_id=job_id,
training_ws_url=training_ws_url.replace(token, '***'))
training_ws = None
try:
# Connect to training service WebSocket
import websockets
from websockets.protocol import State
training_ws = await websockets.connect(
training_ws_url,
ping_interval=120, # Send ping every 2 minutes (tolerates long training operations)
ping_timeout=60, # Wait up to 1 minute for pong (graceful timeout)
close_timeout=60, # Increase close timeout for graceful shutdown
ping_interval=120,
ping_timeout=60,
close_timeout=60,
open_timeout=30
)
logger.info("Gateway connected to training service WebSocket", job_id=job_id)
async def forward_frontend_to_training():
"""Forward messages from frontend to training service"""
try:
while training_ws and training_ws.state == State.OPEN:
data = await websocket.receive()
if data.get("type") == "websocket.receive":
if "text" in data:
await training_ws.send(data["text"])
@@ -683,30 +417,17 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
await training_ws.send(data["bytes"])
elif data.get("type") == "websocket.disconnect":
break
except Exception as e:
logger.debug("Frontend to training forward ended", error=str(e))
except Exception:
pass
async def forward_training_to_frontend():
"""Forward messages from training service to frontend"""
message_count = 0
try:
while training_ws and training_ws.state == State.OPEN:
message = await training_ws.recv()
await websocket.send_text(message)
message_count += 1
except Exception:
pass
# Log every 10th message to track connectivity
if message_count % 10 == 0:
logger.debug("WebSocket proxy active",
job_id=job_id,
messages_forwarded=message_count)
except Exception as e:
logger.info("Training to frontend forward ended",
job_id=job_id,
messages_forwarded=message_count,
error=str(e))
# Run both forwarding tasks concurrently
await asyncio.gather(
forward_frontend_to_training(),
forward_training_to_frontend(),
@@ -716,20 +437,17 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
except Exception as e:
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
finally:
# Cleanup
if training_ws and training_ws.state == State.OPEN:
try:
await training_ws.close()
except:
pass
try:
if not websocket.client_state.name == 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy closed")
except:
pass
logger.info("WebSocket proxy connection closed", job_id=job_id)
if __name__ == "__main__":
import uvicorn