"""
API Gateway - Central entry point for all microservices

Handles routing, authentication, rate limiting, and cross-cutting concerns
"""

import asyncio
import json
import os
import resource
import time

import httpx
import structlog
from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse

from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from shared.service_base import StandardFastAPIService

from app.core.config import settings
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.rate_limiting import APIRateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context

# Initialize logger
logger = structlog.get_logger()

# Check file descriptor limits
try:
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft_limit < 1024:
        logger.warning(f"Low file descriptor limit detected: {soft_limit}")
    else:
        logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
    logger.debug(f"Could not check file descriptor limits: {e}")

# Redis client for SSE streaming
redis_client = None


class GatewayService(StandardFastAPIService):
    """Gateway Service with standardized monitoring setup"""

    async def on_startup(self, app):
        """Custom startup logic for Gateway"""
        global redis_client

        # Initialize Redis
        try:
            await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
            redis_client = await get_redis_client()
            logger.info("Connected to Redis for SSE streaming")

            # Add API rate limiting middleware with Redis client
            app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
            logger.info("API rate limiting middleware enabled")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")

        # Register custom metrics for gateway-specific operations
        if self.telemetry_providers and self.telemetry_providers.app_metrics:
            logger.info("Gateway-specific metrics tracking enabled")

        await super().on_startup(app)

    async def on_shutdown(self, app):
        """Custom shutdown logic for Gateway"""
        await super().on_shutdown(app)

        # Close Redis
        await close_redis()
        logger.info("Redis connection closed")


# Create service instance
service = GatewayService(
    service_name="gateway",
    app_name="Bakery Forecasting API Gateway",
    description="Central API Gateway for bakery forecasting microservices",
    version="1.0.0",
    log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
    cors_origins=settings.CORS_ORIGINS_LIST,
    enable_metrics=True,
    enable_health_checks=True,
    enable_tracing=True,
    enable_cors=True
)

# Create FastAPI app
app = service.create_app()

# Add gateway-specific middleware (in REVERSE order of execution)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware
#                  -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(AuthMiddleware)
app.add_middleware(DemoMiddleware)
app.add_middleware(RequestIDMiddleware)

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================

def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
    """Determine which Redis channels to subscribe to based on filters"""
    all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
    all_classes = ["alerts", "notifications"]
    channels = []

    if not channel_filters:
        # Subscribe to ALL channels (backward compatible)
        for domain in all_domains:
            for event_class in all_classes:
                channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        channels.append(f"tenant:{tenant_id}:recommendations")
        channels.append(f"alerts:{tenant_id}")  # Legacy
        return channels

    # Parse filters and expand wildcards
    for filter_pattern in channel_filters:
        if filter_pattern == "*.*":
            for domain in all_domains:
                for event_class in all_classes:
                    channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
            channels.append(f"tenant:{tenant_id}:recommendations")
        elif filter_pattern.endswith(".*"):
            domain = filter_pattern.split(".")[0]
            for event_class in all_classes:
                channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        elif filter_pattern.startswith("*."):
            event_class = filter_pattern.split(".")[1]
            if event_class == "recommendations":
                channels.append(f"tenant:{tenant_id}:recommendations")
            else:
                for domain in all_domains:
                    channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
        elif filter_pattern == "recommendations":
            channels.append(f"tenant:{tenant_id}:recommendations")
        else:
            channels.append(f"tenant:{tenant_id}:{filter_pattern}")

    return list(set(channels))

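
# Illustrative expansion of the filter logic above (hypothetical tenant id and filters):
#   _get_subscription_channels("t-42", ["inventory.*", "*.notifications", "recommendations"])
# yields channels such as
#   tenant:t-42:inventory.alerts, tenant:t-42:inventory.notifications,
#   tenant:t-42:production.notifications, ..., tenant:t-42:recommendations
# deduplicated and in arbitrary order, since the result passes through set().
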
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
    """Load initial state from Redis cache based on channel filters"""
    initial_events = []

    try:
        if not channel_filters:
            # Legacy cache
            legacy_cache_key = f"active_alerts:{tenant_id}"
            cached_data = await redis_client.get(legacy_cache_key)
            if cached_data:
                return json.loads(cached_data)

            # New domain-specific caches
            all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
            all_classes = ["alerts", "notifications"]

            for domain in all_domains:
                for event_class in all_classes:
                    cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
                    cached_data = await redis_client.get(cache_key)
                    if cached_data:
                        initial_events.extend(json.loads(cached_data))

            # Recommendations
            recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
            cached_data = await redis_client.get(recommendations_cache_key)
            if cached_data:
                initial_events.extend(json.loads(cached_data))

            return initial_events

        # Load based on specific filters
        for filter_pattern in channel_filters:
            if "." in filter_pattern:
                parts = filter_pattern.split(".")
                domain = parts[0] if parts[0] != "*" else None
                event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None

                if domain and event_class:
                    cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
                    cached_data = await redis_client.get(cache_key)
                    if cached_data:
                        initial_events.extend(json.loads(cached_data))
                elif domain and not event_class:
                    for ec in ["alerts", "notifications"]:
                        cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
                        cached_data = await redis_client.get(cache_key)
                        if cached_data:
                            initial_events.extend(json.loads(cached_data))
                elif not domain and event_class:
                    all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
                    for d in all_domains:
                        cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
                        cached_data = await redis_client.get(cache_key)
                        if cached_data:
                            initial_events.extend(json.loads(cached_data))
            elif filter_pattern == "recommendations":
                cache_key = f"active_events:{tenant_id}:recommendations"
                cached_data = await redis_client.get(cache_key)
                if cached_data:
                    initial_events.extend(json.loads(cached_data))

        return initial_events

    except Exception as e:
        logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
        return []


def _determine_event_type(event_data: dict) -> str:
    """Determine SSE event type from event data"""
    if 'event_class' in event_data:
        return event_data['event_class']
    if 'item_type' in event_data:
        if event_data['item_type'] == 'recommendation':
            return 'recommendation'
        else:
            return 'alert'
    return 'alert'

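
# Illustrative mapping for the helper above (hypothetical payload values):
#   {"event_class": "inventory.alerts", ...}  -> SSE event type "inventory.alerts"
#   {"item_type": "recommendation", ...}      -> SSE event type "recommendation"
#   {"item_type": "stockout_warning", ...}    -> SSE event type "alert" (fallback)
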

# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================

@app.get("/api/events")
|
2025-11-27 15:52:40 +01:00
|
|
|
async def events_stream(
|
|
|
|
|
request: Request,
|
|
|
|
|
tenant_id: str,
|
2026-01-09 23:14:12 +01:00
|
|
|
channels: str = None
|
2025-11-27 15:52:40 +01:00
|
|
|
):
|
2025-10-02 13:20:30 +02:00
|
|
|
"""
|
2025-11-27 15:52:40 +01:00
|
|
|
Server-Sent Events stream for real-time notifications with multi-channel support.
|
2025-10-02 13:20:30 +02:00
|
|
|
|
2025-11-27 15:52:40 +01:00
|
|
|
Query Parameters:
|
|
|
|
|
tenant_id: Tenant identifier (required)
|
|
|
|
|
channels: Comma-separated channel filters (optional)
|
2025-10-02 13:20:30 +02:00
|
|
|
"""
|
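    # Example request shape (hypothetical tenant id; authentication is handled by AuthMiddleware,
    # which is expected to populate request.state.user before this handler runs):
    #   GET /api/events?tenant_id=t-42&channels=inventory.alerts,*.notifications
    # Omitting `channels` subscribes the stream to all domains plus the legacy alerts channel.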
    global redis_client

    if not redis_client:
        raise HTTPException(status_code=503, detail="SSE service unavailable")

    # Extract user context from request state
    user_context = request.state.user
    user_id = user_context.get('user_id')
    email = user_context.get('email')

    if not tenant_id:
        raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

    # Parse channel filters
    channel_filters = []
    if channels:
        channel_filters = [c.strip() for c in channels.split(',') if c.strip()]

    logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")

    async def event_generator():
        """Generate server-sent events from Redis pub/sub"""
        pubsub = None
        try:
            pubsub = redis_client.pubsub()
            logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")

            # Determine channels
            subscription_channels = _get_subscription_channels(tenant_id, channel_filters)

            # Subscribe
            if subscription_channels:
                await pubsub.subscribe(*subscription_channels)
                logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
            else:
                legacy_channel = f"alerts:{tenant_id}"
                await pubsub.subscribe(legacy_channel)

            # Connection event
            yield "event: connection\n"
            yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"

            # Initial state
            initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
            if initial_events:
                logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
                yield "event: initial_state\n"
                yield f"data: {json.dumps(initial_events)}\n\n"

            heartbeat_counter = 0

            while True:
                if await request.is_disconnected():
                    logger.info(f"SSE client disconnected for tenant: {tenant_id}")
                    break

                try:
                    message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)

                    if message and message['type'] == 'message':
                        event_data = json.loads(message['data'])
                        event_type = _determine_event_type(event_data)
                        event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']

                        yield f"event: {event_type}\n"
                        yield f"data: {json.dumps(event_data)}\n\n"

                except asyncio.TimeoutError:
                    # Emit a heartbeat roughly every 10 poll timeouts (~100 s) to keep the connection alive
                    heartbeat_counter += 1
                    if heartbeat_counter >= 10:
                        yield "event: heartbeat\n"
                        yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
                        heartbeat_counter = 0

        except asyncio.CancelledError:
            logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
        except Exception as e:
            logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
        finally:
            if pubsub:
                try:
                    await pubsub.unsubscribe()
                    await pubsub.close()
                except Exception as e:
                    logger.error(f"Error closing pubsub: {e}")
            logger.info(f"SSE connection closed for tenant: {tenant_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        }
    )

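
# Illustrative wire format for the stream above (hypothetical payload values). Each message is a
# standard SSE frame: an `event:` line, a `data:` line carrying a JSON body, then a blank line:
#
#   event: connection
#   data: {"type": "connected", "message": "SSE connection established", ...}
#
#   event: inventory.alerts
#   data: {"event_class": "inventory.alerts", "_channel": "tenant:t-42:inventory.alerts", ...}
#
#   event: heartbeat
#   data: {"type": "heartbeat", "timestamp": 1700000000.0}
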

# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
2025-08-14 13:26:59 +02:00
|
|
|
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
|
2026-01-09 23:14:12 +01:00
|
|
|
"""WebSocket proxy with token verification for training service"""
|
2025-08-14 13:26:59 +02:00
|
|
|
token = websocket.query_params.get("token")
|
|
|
|
|
if not token:
|
2025-10-07 07:15:07 +02:00
|
|
|
await websocket.accept()
|
2025-08-14 13:26:59 +02:00
|
|
|
await websocket.close(code=1008, reason="Authentication token required")
|
|
|
|
|
return
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-10-09 14:11:02 +02:00
|
|
|
# Verify token
|
|
|
|
|
from shared.auth.jwt_handler import JWTHandler
|
|
|
|
|
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
payload = jwt_handler.verify_token(token)
|
|
|
|
|
if not payload or not payload.get('user_id'):
|
|
|
|
|
await websocket.accept()
|
|
|
|
|
await websocket.close(code=1008, reason="Invalid token")
|
|
|
|
|
return
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await websocket.accept()
|
|
|
|
|
await websocket.close(code=1008, reason="Token verification failed")
|
|
|
|
|
return
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-10-09 14:11:02 +02:00
|
|
|
await websocket.accept()
|
|
|
|
|
|
|
|
|
|
# Build WebSocket URL to training service
|
2025-08-14 13:26:59 +02:00
|
|
|
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
|
2025-08-15 17:53:59 +02:00
|
|
|
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
|
2025-10-06 15:27:01 +02:00
|
|
|
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
|
2025-09-29 07:54:25 +02:00
|
|
|
|
|
|
|
|
training_ws = None
|
|
|
|
|
|
2025-08-14 13:26:59 +02:00
|
|
|
try:
|
2025-08-15 17:53:59 +02:00
|
|
|
import websockets
|
2025-10-17 23:09:40 +02:00
|
|
|
from websockets.protocol import State
|
2025-09-29 07:54:25 +02:00
|
|
|
|
|
|
|
|
training_ws = await websockets.connect(
|
|
|
|
|
training_ws_url,
|
2026-01-09 23:14:12 +01:00
|
|
|
ping_interval=120,
|
|
|
|
|
ping_timeout=60,
|
|
|
|
|
close_timeout=60,
|
2025-10-09 14:11:02 +02:00
|
|
|
open_timeout=30
|
2025-09-29 07:54:25 +02:00
|
|
|
)
|
|
|
|
|
|
2025-10-09 14:11:02 +02:00
|
|
|
async def forward_frontend_to_training():
|
2025-09-29 07:54:25 +02:00
|
|
|
try:
|
2025-10-17 23:09:40 +02:00
|
|
|
while training_ws and training_ws.state == State.OPEN:
|
2025-10-07 07:15:07 +02:00
|
|
|
data = await websocket.receive()
|
|
|
|
|
if data.get("type") == "websocket.receive":
|
|
|
|
|
if "text" in data:
|
|
|
|
|
await training_ws.send(data["text"])
|
|
|
|
|
elif "bytes" in data:
|
|
|
|
|
await training_ws.send(data["bytes"])
|
|
|
|
|
elif data.get("type") == "websocket.disconnect":
|
2025-09-29 07:54:25 +02:00
|
|
|
break
|
2026-01-09 23:14:12 +01:00
|
|
|
except Exception:
|
|
|
|
|
pass
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-10-09 14:11:02 +02:00
|
|
|
async def forward_training_to_frontend():
|
2025-09-29 07:54:25 +02:00
|
|
|
try:
|
2025-10-17 23:09:40 +02:00
|
|
|
while training_ws and training_ws.state == State.OPEN:
|
2025-10-07 07:15:07 +02:00
|
|
|
message = await training_ws.recv()
|
|
|
|
|
await websocket.send_text(message)
|
2026-01-09 23:14:12 +01:00
|
|
|
except Exception:
|
|
|
|
|
pass
|
|
|
|
|
|
2025-10-07 07:15:07 +02:00
|
|
|
await asyncio.gather(
|
2025-10-09 14:11:02 +02:00
|
|
|
forward_frontend_to_training(),
|
|
|
|
|
forward_training_to_frontend(),
|
2025-10-07 07:15:07 +02:00
|
|
|
return_exceptions=True
|
|
|
|
|
)
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-08-14 13:26:59 +02:00
|
|
|
except Exception as e:
|
2025-10-09 14:11:02 +02:00
|
|
|
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
|
2025-08-15 17:53:59 +02:00
|
|
|
finally:
|
2025-10-17 23:09:40 +02:00
|
|
|
if training_ws and training_ws.state == State.OPEN:
|
2025-09-29 07:54:25 +02:00
|
|
|
try:
|
|
|
|
|
await training_ws.close()
|
2025-10-09 14:11:02 +02:00
|
|
|
except:
|
|
|
|
|
pass
|
2025-09-29 07:54:25 +02:00
|
|
|
try:
|
|
|
|
|
if not websocket.client_state.name == 'DISCONNECTED':
|
2025-10-07 07:15:07 +02:00
|
|
|
await websocket.close(code=1000, reason="Proxy closed")
|
2025-10-09 14:11:02 +02:00
|
|
|
except:
|
|
|
|
|
pass
|
2025-09-29 07:54:25 +02:00
|
|
|
|
2025-08-14 13:26:59 +02:00
|
|
|
|
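
# Illustrative client connection to the proxy above (hypothetical host, ids, and token):
#   wss://gateway.example.com/api/v1/tenants/t-42/training/jobs/job-7/live?token=<JWT>
# The gateway verifies the JWT, then relays frames in both directions to the same path
# on the training service until either side disconnects.
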

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)