Files
bakery-ia/services/training/app/websocket/manager.py

121 lines
4.3 KiB
Python

"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
"""
import asyncio
import json
from typing import Dict, Set
from fastapi import WebSocket
import structlog
logger = structlog.get_logger()
class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.

    Manages connections per job_id and broadcasts messages to all connected
    clients. The most recent broadcast message for each job (other than
    ``initial_state`` messages) is cached and sent to clients that connect
    later, wrapped in an ``initial_state`` envelope.

    Registry mutations happen under ``self._lock``. Network sends happen
    outside the lock, so a slow or stalled client cannot block other
    connects/disconnects.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}, websocket_id == id(websocket)
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state.
        # NOTE(review): entries are never evicted, so this grows with the number
        # of distinct job_ids broadcast during the process lifetime.
        self._latest_events: Dict[str, dict] = {}

    def _remove_connections_locked(self, job_id: str, ws_ids) -> None:
        """
        Drop the given websocket ids from ``job_id``'s registry.

        Deletes the job entry once it has no connections left, so empty dicts
        never linger. Tolerates an unknown job_id or ids that are already
        removed. The caller must hold ``self._lock``.
        """
        job_connections = self._connections.get(job_id)
        if job_connections is None:
            return
        for ws_id in ws_ids:
            job_connections.pop(ws_id, None)
        # Clean up empty job connections
        if not job_connections:
            del self._connections[job_id]

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Accept and register a new WebSocket connection for a job.

        If an event has previously been broadcast for ``job_id``, it is sent to
        the new client as ``{"type": "initial_state", "job_id": ..., "data": ...}``.
        If that send fails, the failure is logged and the connection stays
        registered; a later failed ``broadcast`` will prune it.
        """
        await websocket.accept()
        ws_id = id(websocket)

        async with self._lock:
            self._connections.setdefault(job_id, {})[ws_id] = websocket
            # Snapshot under the lock. The (possibly slow) send happens after
            # the lock is released.
            latest_event = self._latest_events.get(job_id)
            total_connections = len(self._connections[job_id])

        # Send initial state if available
        if latest_event is not None:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": latest_event
                })
            except Exception as e:
                logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=total_connections)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Remove a WebSocket connection for a job.

        Safe to call for a connection or job that is not registered, for
        example one that ``broadcast`` already pruned after a failed send.
        """
        # Bound before any branching so the log call below can always use it.
        ws_id = id(websocket)
        async with self._lock:
            self._remove_connections_locked(job_id, (ws_id,))
            remaining = len(self._connections.get(job_id, {}))

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.

        Messages whose ``type`` is not ``initial_state`` are also cached as
        the job's latest event for future connections. Connections whose send
        raises are removed from the registry.

        Returns the number of successful broadcasts.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
            self._latest_events[job_id] = message

        job_connections = self._connections.get(job_id)
        if not job_connections:
            logger.debug("No active connections for job", job_id=job_id)
            return 0

        # Copy: the registry may change while we await the sends below.
        connections = list(job_connections.values())
        successful_sends = 0
        failed_websockets = []

        for websocket in connections:
            try:
                await websocket.send_json(message)
                successful_sends += 1
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)

        # Clean up failed connections. The job entry may already have been
        # removed by a concurrent disconnect while we were awaiting the sends,
        # so the helper must not index it directly.
        if failed_websockets:
            async with self._lock:
                self._remove_connections_locked(
                    job_id, [id(ws) for ws in failed_websockets]
                )

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job"""
        return len(self._connections.get(job_id, {}))
# Process-wide shared manager instance. Connection and latest-event state live
# on the instance, so import this object rather than constructing a new manager.
websocket_manager: WebSocketConnectionManager = WebSocketConnectionManager()