"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
"""

import asyncio
import json
from typing import Dict, Set

from fastapi import WebSocket
import structlog

logger = structlog.get_logger()


class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.

    Manages connections per job_id and broadcasts messages to all connected
    clients. The most recent broadcast for each job is cached so that clients
    connecting later immediately receive it as an ``initial_state`` message.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}, where websocket_id is id(websocket).
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        # Guards structural changes to self._connections (adding/removing sockets and job entries).
        self._lock = asyncio.Lock()
        # Latest non-initial_state message broadcast for each job, replayed to new connections.
        # NOTE(review): entries are never evicted, so this grows with the number of distinct
        # job_ids seen over the process lifetime — consider clearing it when a job finishes.
        self._latest_events: Dict[str, dict] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Accept and register a WebSocket connection for a job.

        If a message has already been broadcast for ``job_id``, it is sent to the
        new client wrapped in an ``initial_state`` envelope. A failure to send that
        initial state is logged and otherwise ignored; the connection stays
        registered (a later failed broadcast will unregister it).
        """
        await websocket.accept()
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.setdefault(job_id, {})
            job_connections[ws_id] = websocket
            total_connections = len(job_connections)
            # Snapshot the cached event under the lock, but send it after releasing the
            # lock so one slow client cannot stall connect/disconnect/cleanup for all jobs.
            latest_event = self._latest_events.get(job_id)

        # Send initial state if available
        if latest_event is not None:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": latest_event,
                })
            except Exception as e:
                # Best-effort: the client simply misses the replayed state.
                logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=total_connections)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Remove a WebSocket connection for a job.

        Safe to call more than once, or for a socket that is no longer registered
        (for example one already dropped by ``broadcast`` after a failed send).
        The job's entry is removed once its last connection is gone.
        """
        # Computed before any branching so the log call below is always valid,
        # even when job_id has no registered connections.
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.get(job_id)
            if job_connections is not None:
                job_connections.pop(ws_id, None)
                # Clean up empty job connections
                if not job_connections:
                    del self._connections[job_id]
            remaining_connections = len(self._connections.get(job_id, {}))

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining_connections)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.

        The message is also cached as the job's latest event (unless it is itself
        an ``initial_state`` message) so that future connections receive it.
        Connections whose send raises are unregistered; a job entry left with no
        connections is removed.

        Returns the number of connections the message was successfully sent to.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
            self._latest_events[job_id] = message

        job_connections = self._connections.get(job_id)
        if not job_connections:
            logger.debug("No active connections for job", job_id=job_id)
            return 0

        # Snapshot the sockets: connect/disconnect may mutate the dict while we await sends.
        connections = list(job_connections.values())
        successful_sends = 0
        failed_websockets = []

        for websocket in connections:
            try:
                await websocket.send_json(message)
                successful_sends += 1
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)

        # Clean up failed connections
        if failed_websockets:
            async with self._lock:
                # Re-read under the lock: disconnect() may have removed the job entry
                # while the sends above were awaited.
                current_connections = self._connections.get(job_id)
                if current_connections is not None:
                    for ws in failed_websockets:
                        current_connections.pop(id(ws), None)
                    # Mirror disconnect(): drop the job entry once it is empty.
                    if not current_connections:
                        del self._connections[job_id]

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job"""
        return len(self._connections.get(job_id, {}))


# Global singleton instance
websocket_manager = WebSocketConnectionManager()