Improve monitoring 5
@@ -8,13 +8,12 @@ import json
import structlog
import resource
import os
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
import httpx
import time
from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
import httpx
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from typing import Dict, Any
from shared.service_base import StandardFastAPIService

from app.core.config import settings
from app.middleware.request_id import RequestIDMiddleware
@@ -26,128 +25,84 @@ from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, notification, nominatim, subscription, demo, pos, geocoding, poi_context
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector, add_metrics_middleware
from shared.monitoring.system_metrics import SystemMetricsCollector

# OpenTelemetry imports
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.sdk.resources import Resource

# Configure OpenTelemetry tracing
def setup_tracing(service_name: str = "gateway"):
"""Initialize OpenTelemetry tracing with OTLP exporter for Jaeger"""
# Create resource with service name
resource = Resource.create({"service.name": service_name})

# Configure OTLP exporter (sends to OpenTelemetry Collector)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "http://signoz-otel-collector.bakery-ia.svc.cluster.local:4317"),
insecure=True # Use insecure connection for internal cluster communication
)

# Configure tracer provider
provider = TracerProvider(resource=resource)
processor = BatchSpanProcessor(otlp_exporter)
provider.add_span_processor(processor)

# Set global tracer provider
trace.set_tracer_provider(provider)

return provider

# Initialize tracing
tracer_provider = setup_tracing("gateway")
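# Illustrative sketch (not part of this commit): with the provider above installed as
# the global tracer provider, any module can open manual spans. FastAPIInstrumentor
# already creates one span per request, so a manual child span is only needed for
# extra detail; the span and attribute names below are examples.
example_tracer = trace.get_tracer(__name__)

async def _example_traced_upstream_call(client: httpx.AsyncClient, url: str) -> httpx.Response:
    # The span becomes a child of the active request span and is exported
    # through the OTLP pipeline configured in setup_tracing().
    with example_tracer.start_as_current_span("gateway.proxy_upstream") as span:
        span.set_attribute("upstream.url", url)
        response = await client.get(url)
        span.set_attribute("http.status_code", response.status_code)
        return response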

# Setup logging
setup_logging("gateway", settings.LOG_LEVEL)
# Initialize logger
logger = structlog.get_logger()

# Check file descriptor limits and warn if too low
# Check file descriptor limits
try:
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft_limit < 1024:
logger.warning(f"Low file descriptor limit detected: {soft_limit}. Gateway may experience 'too many open files' errors.")
logger.warning("Recommended: Increase limit with 'ulimit -n 4096' or higher for production.")
if soft_limit < 256:
logger.error(f"Critical: File descriptor limit ({soft_limit}) is too low for gateway operation!")
logger.warning(f"Low file descriptor limit detected: {soft_limit}")
else:
logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
logger.debug(f"Could not check file descriptor limits: {e}")
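# Illustrative sketch (not part of this commit): beyond logging a warning, the process
# can also try to raise its own soft limit up to the permitted hard limit. The 4096
# target simply mirrors the recommendation logged above.
def _try_raise_nofile_limit(target: int = 4096) -> int:
    """Best-effort bump of RLIMIT_NOFILE; returns the resulting soft limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft < target:
        new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
        try:
            # The soft limit may never exceed the hard limit for an unprivileged process.
            resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
            soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
        except (ValueError, OSError) as exc:
            logger.debug(f"Could not raise file descriptor limit: {exc}")
    return soft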

# Check and log current working directory and permissions
try:
cwd = os.getcwd()
logger.info(f"Current working directory: {cwd}")

# Check if we can write to common log locations
test_locations = ["/var/log", "./logs", "."]
for location in test_locations:
try:
test_file = os.path.join(location, ".gateway_permission_test")
with open(test_file, 'w') as f:
f.write("test")
os.remove(test_file)
logger.info(f"Write permission confirmed for: {location}")
except Exception as e:
logger.warning(f"Cannot write to {location}: {e}")
except Exception as e:
logger.debug(f"Could not check directory permissions: {e}")

# Create FastAPI app
app = FastAPI(
title="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
redirect_slashes=False # Disable automatic trailing slash redirects
)

# Instrument FastAPI with OpenTelemetry
FastAPIInstrumentor.instrument_app(app)

# Instrument httpx for outgoing requests
HTTPXClientInstrumentor().instrument()

# Instrument Redis (will be active once redis client is initialized)
RedisInstrumentor().instrument()

# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")

# Add metrics middleware to track HTTP requests
add_metrics_middleware(app, metrics_collector)

# Redis client for SSE streaming
redis_client = None

# CORS middleware - Add first
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

class GatewayService(StandardFastAPIService):
"""Gateway Service with standardized monitoring setup"""

async def on_startup(self, app):
"""Custom startup logic for Gateway"""
global redis_client

# Initialize Redis
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")

# Add API rate limiting middleware with Redis client
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
logger.info("API rate limiting middleware enabled")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")

# Register custom metrics for gateway-specific operations
if self.telemetry_providers and self.telemetry_providers.app_metrics:
logger.info("Gateway-specific metrics tracking enabled")

await super().on_startup(app)

async def on_shutdown(self, app):
"""Custom shutdown logic for Gateway"""
await super().on_shutdown(app)

# Close Redis
await close_redis()
logger.info("Redis connection closed")


# Create service instance
service = GatewayService(
service_name="gateway",
app_name="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
cors_origins=settings.CORS_ORIGINS_LIST,
enable_metrics=True,
enable_health_checks=True,
enable_tracing=True,
enable_cors=True
)

# Custom middleware - Add in REVERSE order (last added = first executed)
# Create FastAPI app
app = service.create_app()

# Add gateway-specific middleware (in REVERSE order of execution)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware) # Executes 8th (outermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 7th - Simple rate limit
# Note: APIRateLimitMiddleware will be added on startup with Redis client
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 5th
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 4th - Enforce read-only mode
app.add_middleware(AuthMiddleware) # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware) # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware) # Executes 1st (innermost) - Generates request ID for tracing
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(AuthMiddleware)
app.add_middleware(DemoMiddleware)
app.add_middleware(RequestIDMiddleware)

# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
@@ -156,114 +111,18 @@ app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"]
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
# app.include_router(poi_context.router, prefix="/api/v1/poi-context", tags=["poi-context"]) # Removed to implement tenant-based architecture
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])


@app.on_event("startup")
async def startup_event():
"""Application startup"""
global redis_client

logger.info("Starting API Gateway")

# Initialize shared Redis connection
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")

# Add API rate limiting middleware with Redis client
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
logger.info("API rate limiting middleware enabled with subscription-based quotas")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
logger.warning("API rate limiting middleware will fail open (allow all requests)")

metrics_collector.register_counter(
"gateway_auth_requests_total",
"Total authentication requests"
)
metrics_collector.register_counter(
"gateway_auth_responses_total",
"Total authentication responses"
)
metrics_collector.register_counter(
"gateway_auth_errors_total",
"Total authentication errors"
)

metrics_collector.register_histogram(
"gateway_request_duration_seconds",
"Request duration in seconds"
)

logger.info("Metrics registered successfully")

# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz - no metrics server needed
# Initialize system metrics collection
system_metrics = SystemMetricsCollector("gateway")
logger.info("System metrics collection started")

logger.info("Metrics export configured via OpenTelemetry OTLP")

logger.info("API Gateway started successfully")

@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
logger.info("Shutting down API Gateway")

# Close shared Redis connection
await close_redis()

# Clean up service discovery
# await service_discovery.cleanup()

logger.info("API Gateway shutdown complete")

@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"service": "api-gateway",
"version": "1.0.0",
"timestamp": time.time()
}

# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
# The /metrics endpoint is not needed as metrics are pushed automatically
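# Illustrative sketch (not part of this commit): the shared MetricsCollector is an
# internal module whose exact interface is not shown in this diff. Assuming it wraps
# the OpenTelemetry metrics API (consistent with the OTLP export notes above), the
# equivalent direct calls look roughly like this; the instrument names are examples.
from opentelemetry import metrics as otel_metrics

_example_meter = otel_metrics.get_meter("gateway.examples")
_example_auth_requests = _example_meter.create_counter(
    "gateway_auth_requests_example_total",
    description="Example counter; the real instruments are registered at startup",
)
_example_request_duration = _example_meter.create_histogram(
    "gateway_request_duration_example_seconds",
    unit="s",
    description="Example histogram mirroring gateway_request_duration_seconds",
)

def _example_record_auth_request(tenant_id: str, duration_seconds: float) -> None:
    # Attributes become labels/dimensions on the exported series.
    _example_auth_requests.add(1, {"tenant_id": tenant_id})
    _example_request_duration.record(duration_seconds, {"tenant_id": tenant_id})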

# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================

def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
"""
Determine which Redis channels to subscribe to based on filters.

Args:
tenant_id: Tenant identifier
channel_filters: List of channel patterns (e.g., ["inventory.alerts", "*.notifications"])

Returns:
List of full channel names to subscribe to

Examples:
>>> _get_subscription_channels("abc", ["inventory.alerts"])
["tenant:abc:inventory.alerts"]

>>> _get_subscription_channels("abc", ["*.alerts"])
["tenant:abc:inventory.alerts", "tenant:abc:production.alerts", ...]

>>> _get_subscription_channels("abc", [])
["tenant:abc:inventory.alerts", "tenant:abc:inventory.notifications", ...]
"""
"""Determine which Redis channels to subscribe to based on filters"""
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]

channels = []

if not channel_filters:
@@ -271,70 +130,49 @@ def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
# Also subscribe to recommendations (tenant-wide)
channels.append(f"tenant:{tenant_id}:recommendations")
# Also subscribe to legacy channel for backward compatibility
channels.append(f"alerts:{tenant_id}")
channels.append(f"alerts:{tenant_id}") # Legacy
return channels

# Parse filters and expand wildcards
for filter_pattern in channel_filters:
if filter_pattern == "*.*":
# All channels
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")

elif filter_pattern.endswith(".*"):
# Domain wildcard (e.g., "inventory.*")
domain = filter_pattern.split(".")[0]
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")

elif filter_pattern.startswith("*."):
# Class wildcard (e.g., "*.alerts")
event_class = filter_pattern.split(".")[1]
if event_class == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
for domain in all_domains:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")

elif filter_pattern == "recommendations":
# Recommendations channel
channels.append(f"tenant:{tenant_id}:recommendations")

else:
# Specific channel (e.g., "inventory.alerts")
channels.append(f"tenant:{tenant_id}:{filter_pattern}")

return list(set(channels)) # Remove duplicates
return list(set(channels))
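# Illustrative sketch (not part of this commit): the producer side of the channel
# convention expanded above. A hypothetical publisher in another service would
# PUBLISH JSON events to "tenant:{tenant_id}:{domain}.{class}" so the SSE endpoint
# below can fan them out; the function, channel, and payload names are examples.
async def _example_publish_event(redis, tenant_id: str, domain: str, event_class: str, payload: dict) -> int:
    channel = f"tenant:{tenant_id}:{domain}.{event_class}"
    # redis.asyncio publish() returns the number of subscribers that received the message.
    return await redis.publish(channel, json.dumps(payload))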


async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
"""
Load initial state from Redis cache based on channel filters.

Args:
redis_client: Redis client
tenant_id: Tenant identifier
channel_filters: List of channel patterns

Returns:
List of initial events
"""
"""Load initial state from Redis cache based on channel filters"""
initial_events = []

try:
if not channel_filters:
# Load from legacy cache if no filters (backward compat)
# Legacy cache
legacy_cache_key = f"active_alerts:{tenant_id}"
cached_data = await redis_client.get(legacy_cache_key)
if cached_data:
return json.loads(cached_data)

# Also try loading from new domain-specific caches
# New domain-specific caches
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]

@@ -343,10 +181,9 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
events = json.loads(cached_data)
initial_events.extend(events)
initial_events.extend(json.loads(cached_data))

# Load recommendations
# Recommendations
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(recommendations_cache_key)
if cached_data:
@@ -356,36 +193,29 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis

# Load based on specific filters
for filter_pattern in channel_filters:
# Extract domain and class from filter
if "." in filter_pattern:
parts = filter_pattern.split(".")
domain = parts[0] if parts[0] != "*" else None
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None

if domain and event_class:
# Specific cache (e.g., "inventory.alerts")
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))

elif domain and not event_class:
# Domain wildcard (e.g., "inventory.*")
for ec in ["alerts", "notifications"]:
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))

elif not domain and event_class:
# Class wildcard (e.g., "*.alerts")
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
for d in all_domains:
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))

elif filter_pattern == "recommendations":
cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(cache_key)
@@ -400,27 +230,14 @@ async def _load_initial_state(redis_client, tenant_id: str, channel_filters: lis


def _determine_event_type(event_data: dict) -> str:
"""
Determine SSE event type from event data.

Args:
event_data: Event data dictionary

Returns:
SSE event type: 'alert', 'notification', or 'recommendation'
"""
# New event architecture uses 'event_class'
"""Determine SSE event type from event data"""
if 'event_class' in event_data:
return event_data['event_class'] # 'alert', 'notification', or 'recommendation'

# Legacy format uses 'item_type'
return event_data['event_class']
if 'item_type' in event_data:
if event_data['item_type'] == 'recommendation':
return 'recommendation'
else:
return 'alert'

# Default to 'alert' for backward compatibility
return 'alert'


@@ -432,42 +249,25 @@ def _determine_event_type(event_data: dict) -> str:
async def events_stream(
request: Request,
tenant_id: str,
channels: str = None # Comma-separated channel filters (e.g., "inventory.alerts,production.notifications")
channels: str = None
):
"""
Server-Sent Events stream for real-time notifications with multi-channel support.

Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).

Query Parameters:
tenant_id: Tenant identifier (required)
channels: Comma-separated channel filters (optional)
Examples:
- "inventory.alerts,production.notifications" - Specific channels
- "*.alerts" - All alert channels
- "inventory.*" - All inventory events
- None - All channels (default, backward compatible)

New channel pattern: tenant:{tenant_id}:{domain}.{class}
Examples:
- tenant:abc:inventory.alerts
- tenant:abc:production.notifications
- tenant:abc:recommendations

Legacy channel (backward compat): alerts:{tenant_id}
"""
global redis_client

if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")

# Extract user context from request state (set by auth middleware)
# Extract user context from request state
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')

# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")

@@ -479,79 +279,53 @@ async def events_stream(
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")

async def event_generator():
"""Generate server-sent events from Redis pub/sub with multi-channel support"""
"""Generate server-sent events from Redis pub/sub"""
pubsub = None
try:
# Create pubsub connection with resource monitoring
pubsub = redis_client.pubsub()
logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")

# Monitor connection count
try:
connection_info = await redis_client.info('clients')
connected_clients = connection_info.get('connected_clients', 'unknown')
logger.debug(f"Redis connected clients: {connected_clients}")
except Exception:
# Don't fail if we can't get connection info
pass

# Determine which channels to subscribe to
# Determine channels
subscription_channels = _get_subscription_channels(tenant_id, channel_filters)

# Subscribe to all determined channels
# Subscribe
if subscription_channels:
await pubsub.subscribe(*subscription_channels)
logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
else:
# Fallback to legacy channel if no channels specified
legacy_channel = f"alerts:{tenant_id}"
await pubsub.subscribe(legacy_channel)
logger.info(f"Subscribed to legacy channel: {legacy_channel}")

# Send initial connection event
# Connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"

# Fetch and send initial state from cache (domain-specific or legacy)
# Initial state
initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
if initial_events:
logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"
else:
# Send empty initial state for compatibility
yield f"event: initial_state\n"
yield f"data: {json.dumps([])}\n\n"
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"

heartbeat_counter = 0

while True:
# Check if client has disconnected
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break

try:
# Get message from Redis with timeout
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)

if message and message['type'] == 'message':
# Forward the event from Redis
event_data = json.loads(message['data'])

# Determine event type for SSE
event_type = _determine_event_type(event_data)

# Add channel metadata for frontend routing
event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']

yield f"event: {event_type}\n"
yield f"data: {json.dumps(event_data)}\n\n"

logger.debug(f"SSE event sent to tenant {tenant_id}: {event_type} - {event_data.get('title')}")

except asyncio.TimeoutError:
# Send heartbeat every 10 timeouts (100 seconds)
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
@@ -563,24 +337,13 @@ async def events_stream(
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
finally:
try:
if pubsub:
try:
# Unsubscribe from all channels
await pubsub.unsubscribe()
logger.debug(f"Unsubscribed from Redis channels for tenant: {tenant_id}")
except Exception as unsubscribe_error:
logger.error(f"Failed to unsubscribe Redis pubsub for tenant {tenant_id}: {unsubscribe_error}")

try:
# Close pubsub connection
await pubsub.close()
logger.debug(f"Closed Redis pubsub connection for tenant: {tenant_id}")
except Exception as close_error:
logger.error(f"Failed to close Redis pubsub for tenant {tenant_id}: {close_error}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")
except Exception as finally_error:
logger.error(f"Error in SSE cleanup for tenant {tenant_id}: {finally_error}")
if pubsub:
try:
await pubsub.unsubscribe()
await pubsub.close()
except Exception as e:
logger.error(f"Error closing pubsub: {e}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")

return StreamingResponse(
event_generator(),
@@ -593,55 +356,35 @@ async def events_stream(
}
)
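# Illustrative sketch (not part of this commit): how a Python client could consume
# this stream. The gateway base URL, token handling, and the "/api/v1/events/stream"
# path are assumptions for the example; the real route prefix is defined where this
# endpoint is registered, which is outside the visible hunks.
async def _example_consume_sse(base_url: str, token: str, tenant_id: str) -> None:
    params = {"tenant_id": tenant_id, "channels": "*.alerts", "token": token}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", f"{base_url}/api/v1/events/stream", params=params) as response:
            async for line in response.aiter_lines():
                # SSE frames arrive as "event: <type>" / "data: <json>" pairs separated by blank lines.
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):]))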


# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================

@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""
Simple WebSocket proxy with token verification only.
Validates the token and forwards the connection to the training service.
"""
# Get token from query params
"""WebSocket proxy with token verification for training service"""
token = websocket.query_params.get("token")
if not token:
logger.warning("WebSocket proxy rejected - missing token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return

# Verify token
from shared.auth.jwt_handler import JWTHandler

jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

try:
payload = jwt_handler.verify_token(token)
if not payload or not payload.get('user_id'):
logger.warning("WebSocket proxy rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Invalid token")
return

logger.info("WebSocket proxy - token verified",
user_id=payload.get('user_id'),
tenant_id=tenant_id,
job_id=job_id)

except Exception as e:
logger.warning("WebSocket proxy - token verification failed",
job_id=job_id,
error=str(e))
await websocket.accept()
await websocket.close(code=1008, reason="Token verification failed")
return

# Accept the connection
await websocket.accept()

# Build WebSocket URL to training service
@@ -649,33 +392,24 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"

logger.info("Gateway proxying WebSocket to training service",
job_id=job_id,
training_ws_url=training_ws_url.replace(token, '***'))

training_ws = None

try:
# Connect to training service WebSocket
import websockets
from websockets.protocol import State

training_ws = await websockets.connect(
training_ws_url,
ping_interval=120, # Send ping every 2 minutes (tolerates long training operations)
ping_timeout=60, # Wait up to 1 minute for pong (graceful timeout)
close_timeout=60, # Increase close timeout for graceful shutdown
ping_interval=120,
ping_timeout=60,
close_timeout=60,
open_timeout=30
)

logger.info("Gateway connected to training service WebSocket", job_id=job_id)

async def forward_frontend_to_training():
"""Forward messages from frontend to training service"""
try:
while training_ws and training_ws.state == State.OPEN:
data = await websocket.receive()

if data.get("type") == "websocket.receive":
if "text" in data:
await training_ws.send(data["text"])
@@ -683,30 +417,17 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
await training_ws.send(data["bytes"])
elif data.get("type") == "websocket.disconnect":
break
except Exception as e:
logger.debug("Frontend to training forward ended", error=str(e))
except Exception:
pass

async def forward_training_to_frontend():
"""Forward messages from training service to frontend"""
message_count = 0
try:
while training_ws and training_ws.state == State.OPEN:
message = await training_ws.recv()
await websocket.send_text(message)
message_count += 1
except Exception:
pass

# Log every 10th message to track connectivity
if message_count % 10 == 0:
logger.debug("WebSocket proxy active",
job_id=job_id,
messages_forwarded=message_count)
except Exception as e:
logger.info("Training to frontend forward ended",
job_id=job_id,
messages_forwarded=message_count,
error=str(e))

# Run both forwarding tasks concurrently
await asyncio.gather(
forward_frontend_to_training(),
forward_training_to_frontend(),
@@ -716,20 +437,17 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
except Exception as e:
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
finally:
# Cleanup
if training_ws and training_ws.state == State.OPEN:
try:
await training_ws.close()
except:
pass

try:
if not websocket.client_state.name == 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy closed")
except:
pass

logger.info("WebSocket proxy connection closed", job_id=job_id)

if __name__ == "__main__":
import uvicorn