"""
WebSocket Connection Manager for Training Service

Manages WebSocket connections and broadcasts RabbitMQ events to connected clients.

HORIZONTAL SCALING:
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
- Each pod subscribes to a Redis channel and broadcasts to its local connections
- Events published to Redis are received by all pods, ensuring clients on any
  pod receive events from training jobs running on any other pod
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
from typing import Dict, Optional
|
|
from fastapi import WebSocket
|
|
import structlog
|
|
|
|
# Module-level structlog logger; calls in this file pass context as keyword arguments.
logger = structlog.get_logger()
|
|
|
|
# Redis pub/sub channel for WebSocket events.
# Every pod publishes to and subscribes on this same channel name.
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"
|
|
|
|
|
|
class WebSocketConnectionManager:
    """
    WebSocket connection manager with Redis pub/sub for horizontal scaling.

    In a multi-pod deployment:
    1. Events are published to Redis pub/sub (not just local broadcast)
    2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
    3. This ensures clients connected to any pod receive events from any pod

    Flow:
    - RabbitMQ event → Pod A receives → Pod A publishes to Redis
    - Redis pub/sub → All pods receive → Each pod broadcasts to local WebSockets
    """

    def __init__(self):
        """Create an empty manager; Redis stays disabled until initialize_redis()."""
        # Structure: {job_id: {websocket_id: WebSocket}}
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        # Serialises structural changes to _connections.
        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state
        self._latest_events: Dict[str, dict] = {}
        # Redis client for pub/sub
        self._redis: Optional[object] = None
        self._pubsub: Optional[object] = None
        self._subscriber_task: Optional[asyncio.Task] = None
        self._running = False
        # Identifies this process in published payloads and log lines.
        self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"

    async def initialize_redis(self, redis_url: str) -> bool:
        """
        Initialize Redis connection for cross-pod pub/sub.

        Any previously initialised client/subscriber is released first so a
        repeated call cannot leave two subscriber tasks delivering every
        event twice.

        Args:
            redis_url: Redis connection URL

        Returns:
            True if successful, False otherwise. On failure all Redis state
            is reset, so broadcast() uses the local-only path.
        """
        if self._redis is not None or self._subscriber_task is not None:
            await self._release_redis()

        try:
            import redis.asyncio as redis_async

            self._redis = redis_async.from_url(redis_url, decode_responses=True)
            await self._redis.ping()

            # Create pub/sub subscriber
            self._pubsub = self._redis.pubsub()
            await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)

            # Start subscriber task
            self._running = True
            self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())

            logger.info("Redis pub/sub initialized for WebSocket broadcasting",
                        instance_id=self._instance_id,
                        channel=REDIS_WEBSOCKET_CHANNEL)
            return True

        except Exception as e:
            logger.error("Failed to initialize Redis pub/sub",
                         error=str(e),
                         instance_id=self._instance_id)
            # Do not keep a half-initialised client around: broadcast() would
            # otherwise try (and fail) to publish on every single event.
            await self._release_redis()
            return False

    async def _release_redis(self) -> None:
        """Stop the subscriber task and close Redis handles, leaving Redis disabled."""
        self._running = False

        # Detach the handles first so no other coroutine picks up a closing client.
        task, self._subscriber_task = self._subscriber_task, None
        pubsub, self._pubsub = self._pubsub, None
        client, self._redis = self._redis, None

        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        if pubsub is not None:
            try:
                await pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
                await pubsub.close()
            except Exception as e:
                # Best effort: still go on to close the client below.
                logger.warning("Error closing Redis pub/sub",
                               error=str(e),
                               instance_id=self._instance_id)

        if client is not None:
            try:
                await client.close()
            except Exception as e:
                logger.warning("Error closing Redis client",
                               error=str(e),
                               instance_id=self._instance_id)

    async def shutdown(self):
        """Shutdown Redis pub/sub connection"""
        await self._release_redis()

        logger.info("Redis pub/sub shutdown complete",
                    instance_id=self._instance_id)

    async def _redis_subscriber_loop(self):
        """Background task to receive Redis pub/sub messages and broadcast locally"""
        # Bind once: teardown clears self._pubsub before cancelling this task.
        pubsub = self._pubsub

        try:
            while self._running and pubsub is not None:
                try:
                    message = await pubsub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=1.0
                    )

                    if message and message['type'] == 'message':
                        await self._handle_redis_message(message['data'])

                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error("Error in Redis subscriber loop",
                                 error=str(e),
                                 instance_id=self._instance_id)
                    await asyncio.sleep(1)  # Backoff on error

        except asyncio.CancelledError:
            pass

        logger.info("Redis subscriber loop stopped",
                    instance_id=self._instance_id)

    async def _handle_redis_message(self, data: str):
        """Handle a message received from Redis pub/sub"""
        try:
            payload = json.loads(data)
            if not isinstance(payload, dict):
                logger.warning("Unexpected Redis message payload",
                               payload_type=type(payload).__name__)
                return

            job_id = payload.get('job_id')
            message = payload.get('message')
            source_instance = payload.get('source_instance')

            if not job_id or not message:
                return

            # Log cross-pod message
            if source_instance != self._instance_id:
                logger.debug("Received cross-pod WebSocket event",
                             job_id=job_id,
                             source_instance=source_instance,
                             local_instance=self._instance_id)

            # Cache here as well, so a client that later connects to THIS pod
            # gets an initial state even though another pod produced the event.
            if isinstance(message, dict):
                self._remember_latest_event(job_id, message)

            # Broadcast to local WebSocket connections
            await self._broadcast_local(job_id, message)

        except json.JSONDecodeError as e:
            logger.warning("Invalid JSON in Redis message", error=str(e))
        except Exception as e:
            logger.error("Error handling Redis message", error=str(e))

    def _remember_latest_event(self, job_id: str, message: dict) -> None:
        """Cache `message` as the job's latest event unless it is an initial_state replay."""
        if message.get('type') != 'initial_state':
            self._latest_events[job_id] = message

    async def connect(self, job_id: str, websocket: "WebSocket") -> None:
        """Register a new WebSocket connection for a job"""
        await websocket.accept()

        ws_id = id(websocket)
        async with self._lock:
            job_connections = self._connections.setdefault(job_id, {})
            job_connections[ws_id] = websocket
            # Snapshot the count under the lock; the entry may be gone again
            # by the time the initial-state send below has finished.
            total_connections = len(job_connections)

        # Send initial state if available
        latest_event = self._latest_events.get(job_id)
        if latest_event is not None:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": latest_event
                })
            except Exception as e:
                logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=total_connections,
                    instance_id=self._instance_id)

    async def disconnect(self, job_id: str, websocket: "WebSocket") -> None:
        """Remove a WebSocket connection"""
        # Computed up front so the log call below always has a value.
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.get(job_id)
            if job_connections is not None:
                job_connections.pop(ws_id, None)

                # Clean up empty job connections
                if not job_connections:
                    del self._connections[job_id]

            remaining_connections = len(self._connections.get(job_id, {}))

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining_connections,
                    instance_id=self._instance_id)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job across ALL pods.

        If Redis is configured, publishes to Redis pub/sub which then broadcasts
        to all pods. Otherwise, falls back to local-only broadcast.

        Returns the number of successful local broadcasts (always 0 when the
        message was handed to Redis, because delivery then happens later in
        the subscriber task).
        """
        # Store the latest event for this job to provide initial state to new connections
        self._remember_latest_event(job_id, message)

        # If Redis is available, publish to Redis for cross-pod broadcast
        if self._redis is not None:
            try:
                payload = json.dumps({
                    'job_id': job_id,
                    'message': message,
                    'source_instance': self._instance_id
                })
                await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
                logger.debug("Published WebSocket event to Redis",
                             job_id=job_id,
                             message_type=message.get('type'),
                             instance_id=self._instance_id)
                # Return 0 here because the actual broadcast happens via subscriber
                # The count will be from _broadcast_local when the message is received
                return 0
            except Exception as e:
                logger.warning("Failed to publish to Redis, falling back to local broadcast",
                               error=str(e),
                               job_id=job_id)
                # Fall through to local broadcast

        # Local-only broadcast (when Redis is not available)
        return await self._broadcast_local(job_id, message)

    async def _broadcast_local(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to local WebSocket connections only.
        This is called either directly (no Redis) or from Redis subscriber.

        Connections whose send fails are dropped from the registry.
        Returns the number of connections the message was sent to.
        """
        job_connections = self._connections.get(job_id)
        if not job_connections:
            logger.debug("No active local connections for job",
                         job_id=job_id,
                         instance_id=self._instance_id)
            return 0

        # Snapshot: each send awaits, so the registry can change underneath us.
        connections = list(job_connections.values())
        successful_sends = 0
        failed_websockets = []

        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)
            else:
                successful_sends += 1

        # Clean up failed connections
        if failed_websockets:
            async with self._lock:
                # Re-read: a concurrent disconnect() may have deleted the job
                # entry while the sends above were awaiting.
                current = self._connections.get(job_id)
                if current is not None:
                    for ws in failed_websockets:
                        current.pop(id(ws), None)
                    # Same pruning as disconnect(): keep no empty job entries.
                    if not current:
                        del self._connections[job_id]

        if successful_sends > 0:
            message_type = message.get('type') if isinstance(message, dict) else None
            logger.info("Broadcasted message to local WebSocket clients",
                        job_id=job_id,
                        message_type=message_type,
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets),
                        instance_id=self._instance_id)

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active local connections for a job"""
        return len(self._connections.get(job_id, {}))

    def get_total_connection_count(self) -> int:
        """Get total number of active connections across all jobs"""
        return sum(len(conns) for conns in self._connections.values())

    def is_redis_enabled(self) -> bool:
        """Check if Redis pub/sub is enabled"""
        return self._redis is not None and self._running
|
|
|
|
|
|
# Global singleton instance.
# Redis stays disabled on it until initialize_redis() is awaited.
websocket_manager = WebSocketConnectionManager()
|