# bakery-ia/services/training/app/websocket/manager.py

"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
2026-01-18 09:02:27 +01:00
HORIZONTAL SCALING:
- Uses Redis pub/sub for cross-pod WebSocket broadcasting
- Each pod subscribes to a Redis channel and broadcasts to its local connections
- Events published to Redis are received by all pods, ensuring clients on any
pod receive events from training jobs running on any other pod
"""
import asyncio
import json
import os
from typing import Any, Dict, Optional

from fastapi import WebSocket
import structlog
# Structured (structlog) logger shared by the connection manager below.
logger = structlog.get_logger()
# Redis pub/sub channel that carries WebSocket events between pods
REDIS_WEBSOCKET_CHANNEL = "training:websocket:events"  # subscribed to and published on by WebSocketConnectionManager
class WebSocketConnectionManager:
    """
    WebSocket connection manager with Redis pub/sub for horizontal scaling.

    In a multi-pod deployment:
    1. Events are published to Redis pub/sub (not just local broadcast)
    2. Each pod subscribes to Redis and broadcasts to its local WebSocket connections
    3. This ensures clients connected to any pod receive events from any pod

    Flow:
    - RabbitMQ event -> Pod A receives -> Pod A publishes to Redis
    - Redis pub/sub -> All pods receive -> Each pod broadcasts to local WebSockets

    Redis is optional: until initialize_redis() succeeds (and again after
    shutdown()), broadcast() delivers to this process's connections only.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        # Guards structural changes to self._connections.
        self._lock = asyncio.Lock()

        # Latest non-"initial_state" event per job, replayed to new connections.
        # NOTE(review): nothing in this class evicts entries, so this grows with
        # the number of distinct job ids seen by the process.
        self._latest_events: Dict[str, dict] = {}

        # Redis client / pub-sub handles. Typed as Any because the redis
        # package is imported lazily inside initialize_redis().
        self._redis: Optional[Any] = None
        self._pubsub: Optional[Any] = None
        self._subscriber_task: Optional[asyncio.Task] = None
        self._running = False
        # "<HOSTNAME>:<pid>" -- tags published events so a pod can recognise
        # its own messages when they come back through Redis.
        self._instance_id = f"{os.environ.get('HOSTNAME', 'unknown')}:{os.getpid()}"

    async def initialize_redis(self, redis_url: str) -> bool:
        """
        Initialize Redis connection for cross-pod pub/sub.

        Connects, pings, subscribes to REDIS_WEBSOCKET_CHANNEL and starts the
        background subscriber task. On any failure every partially created
        resource is released again, so the manager stays in local-only mode
        instead of publishing to a channel nobody on this pod listens to.

        Args:
            redis_url: Redis connection URL

        Returns:
            True if successful (or already initialized), False otherwise
        """
        if self.is_redis_enabled():
            # A second subscriber loop would deliver every event twice locally.
            logger.warning("Redis pub/sub already initialized, skipping",
                           instance_id=self._instance_id)
            return True

        try:
            import redis.asyncio as redis_async

            self._redis = redis_async.from_url(redis_url, decode_responses=True)
            await self._redis.ping()

            # Create pub/sub subscriber
            self._pubsub = self._redis.pubsub()
            await self._pubsub.subscribe(REDIS_WEBSOCKET_CHANNEL)

            # Start subscriber task
            self._running = True
            self._subscriber_task = asyncio.create_task(self._redis_subscriber_loop())

            logger.info("Redis pub/sub initialized for WebSocket broadcasting",
                        instance_id=self._instance_id,
                        channel=REDIS_WEBSOCKET_CHANNEL)
            return True
        except Exception as e:
            logger.error("Failed to initialize Redis pub/sub",
                         error=str(e),
                         instance_id=self._instance_id)
            # Drop half-initialized handles so broadcast() falls back to local.
            await self._teardown()
            return False

    async def shutdown(self):
        """
        Shutdown Redis pub/sub connection.

        Stops the subscriber task, unsubscribes and closes the Redis handles
        on a best-effort basis (close errors are logged, not raised), and
        resets the manager to local-only mode. Safe to call more than once.
        """
        await self._teardown()
        logger.info("Redis pub/sub shutdown complete",
                    instance_id=self._instance_id)

    async def _teardown(self) -> None:
        """Stop the subscriber task and release all Redis handles."""
        self._running = False

        task, self._subscriber_task = self._subscriber_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Detach the handles first so nothing publishes on them while closing.
        pubsub, self._pubsub = self._pubsub, None
        client, self._redis = self._redis, None

        if pubsub is not None:
            try:
                await pubsub.unsubscribe(REDIS_WEBSOCKET_CHANNEL)
            except Exception as e:
                logger.warning("Failed to unsubscribe from Redis channel",
                               error=str(e),
                               instance_id=self._instance_id)
            await self._close_quietly(pubsub, "pubsub")

        if client is not None:
            await self._close_quietly(client, "client")

    async def _close_quietly(self, resource: Any, label: str) -> None:
        """Close a Redis handle, preferring aclose() over the deprecated close()."""
        closer = getattr(resource, "aclose", None) or getattr(resource, "close", None)
        if closer is None:
            return
        try:
            await closer()
        except Exception as e:
            logger.warning("Failed to close Redis resource",
                           resource=label,
                           error=str(e),
                           instance_id=self._instance_id)

    async def _redis_subscriber_loop(self):
        """Background task to receive Redis pub/sub messages and broadcast locally"""
        try:
            while self._running:
                try:
                    # The 1s timeout bounds how long a cleared _running flag
                    # goes unnoticed when the task is not cancelled.
                    message = await self._pubsub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=1.0
                    )
                    if message and message['type'] == 'message':
                        await self._handle_redis_message(message['data'])
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error("Error in Redis subscriber loop",
                                 error=str(e),
                                 instance_id=self._instance_id)
                    await asyncio.sleep(1)  # Backoff on error
        except asyncio.CancelledError:
            # Cancellation that lands during the backoff sleep above.
            pass

        logger.info("Redis subscriber loop stopped",
                    instance_id=self._instance_id)

    async def _handle_redis_message(self, data: str):
        """
        Handle a message received from Redis pub/sub.

        Expects a JSON object with 'job_id', 'message' and 'source_instance'
        (the shape broadcast() publishes). Payloads lacking a job id or a
        message are ignored; malformed payloads are logged and dropped.
        """
        try:
            payload = json.loads(data)
            job_id = payload.get('job_id')
            message = payload.get('message')
            source_instance = payload.get('source_instance')

            if not job_id or not message:
                return

            if source_instance != self._instance_id:
                # Log cross-pod message
                logger.debug("Received cross-pod WebSocket event",
                             job_id=job_id,
                             source_instance=source_instance,
                             local_instance=self._instance_id)
                # The publishing pod cached this event in broadcast(); cache it
                # here too so clients that connect to this pod later still get
                # an initial state.
                self._remember_latest(job_id, message)

            # Broadcast to local WebSocket connections
            await self._broadcast_local(job_id, message)
        except json.JSONDecodeError as e:
            logger.warning("Invalid JSON in Redis message", error=str(e))
        except Exception as e:
            logger.error("Error handling Redis message", error=str(e))

    def _remember_latest(self, job_id: str, message: Any) -> None:
        """Cache `message` as the job's latest event unless it is an 'initial_state' replay."""
        if isinstance(message, dict) and message.get('type') != 'initial_state':
            self._latest_events[job_id] = message

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Register a new WebSocket connection for a job.

        Accepts the socket, registers it under `job_id`, and, if an event has
        already been cached for the job, sends it wrapped in an
        'initial_state' envelope. A failed initial send is logged only; the
        connection stays registered.
        """
        await websocket.accept()
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.setdefault(job_id, {})
            job_connections[ws_id] = websocket
            # Snapshot under the lock; the dict may change once it is released.
            total_connections = len(job_connections)

        # Send initial state if available
        if job_id in self._latest_events:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": self._latest_events[job_id]
                })
            except Exception as e:
                logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=total_connections,
                    instance_id=self._instance_id)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Remove a WebSocket connection.

        Unknown job ids are ignored. When the last connection of a job is
        removed, the job's entry is dropped as well.
        """
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.get(job_id)
            if job_connections is None:
                return

            job_connections.pop(ws_id, None)
            remaining_connections = len(job_connections)

            # Clean up empty job connections
            if not job_connections:
                del self._connections[job_id]

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining_connections,
                    instance_id=self._instance_id)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job across ALL pods.

        If Redis is enabled, publishes to Redis pub/sub; every pod (including
        this one) then delivers it from its subscriber loop. If Redis is not
        enabled, or publishing fails, falls back to local-only broadcast.

        Returns:
            The number of successful local sends. This is 0 whenever the
            message was handed to Redis, because delivery then happens
            asynchronously in the subscriber loop.
        """
        # Store the latest event for this job to provide initial state to new connections
        self._remember_latest(job_id, message)

        # Gate on is_redis_enabled(), not just on the client object: publishing
        # without a running subscriber would silently drop local delivery.
        if self.is_redis_enabled():
            try:
                payload = json.dumps({
                    'job_id': job_id,
                    'message': message,
                    'source_instance': self._instance_id
                })
                await self._redis.publish(REDIS_WEBSOCKET_CHANNEL, payload)
                logger.debug("Published WebSocket event to Redis",
                             job_id=job_id,
                             message_type=message.get('type'),
                             instance_id=self._instance_id)
                # The actual broadcast happens via the subscriber loop.
                return 0
            except Exception as e:
                logger.warning("Failed to publish to Redis, falling back to local broadcast",
                               error=str(e),
                               job_id=job_id)
                # Fall through to local broadcast

        # Local-only broadcast (when Redis is not available)
        return await self._broadcast_local(job_id, message)

    async def _broadcast_local(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to local WebSocket connections only.

        This is called either directly (no Redis) or from Redis subscriber.
        Sockets whose send raises are unregistered afterwards.

        Returns:
            The number of sockets the message was sent to successfully.
        """
        job_connections = self._connections.get(job_id)
        if not job_connections:
            logger.debug("No active local connections for job",
                         job_id=job_id,
                         instance_id=self._instance_id)
            return 0

        successful_sends = 0
        failed_websockets = []

        # Iterate over a snapshot: connect()/disconnect() may mutate the dict
        # while a send is awaited.
        for websocket in list(job_connections.values()):
            try:
                await websocket.send_json(message)
                successful_sends += 1
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)

        # Clean up failed connections
        if failed_websockets:
            async with self._lock:
                # Re-read the mapping: the job may have been removed by
                # disconnect() while the sends above were in flight.
                current = self._connections.get(job_id)
                if current is not None:
                    for ws in failed_websockets:
                        if current.get(id(ws)) is ws:
                            del current[id(ws)]
                    if not current:
                        del self._connections[job_id]

        if successful_sends > 0:
            logger.info("Broadcasted message to local WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets),
                        instance_id=self._instance_id)

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active local connections for a job"""
        return len(self._connections.get(job_id, {}))

    def get_total_connection_count(self) -> int:
        """Get total number of active local connections across all jobs"""
        return sum(len(conns) for conns in self._connections.values())

    def is_redis_enabled(self) -> bool:
        """Check if Redis pub/sub is enabled (client created and subscriber running)"""
        return self._redis is not None and self._running
# Global singleton instance, created at import time. Redis pub/sub stays
# disabled on it until initialize_redis() has been awaited successfully.
websocket_manager = WebSocketConnectionManager()