"""
WebSocket Connection Manager for Training Service.

Manages WebSocket connections and broadcasts RabbitMQ events to connected clients.
"""

import asyncio
import json
from typing import Dict, Set

import structlog
from fastapi import WebSocket

# Module-level structured logger used by the connection manager below.
logger = structlog.get_logger()


class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.

    Manages connections per job_id and broadcasts messages to all connected
    clients. The most recent broadcast message for each job (other than
    ``initial_state`` messages) is cached and replayed to clients that
    connect later, as an ``initial_state`` message.

    Registry mutations are serialized with an ``asyncio.Lock``. Network I/O
    (sending to clients) is never awaited while that lock is held, so one
    slow client cannot stall connects/disconnects for every job.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}
        # websocket_id is id(websocket). The dict holds a strong reference to
        # the socket, so that id cannot be recycled while the entry exists.
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state.
        # NOTE(review): entries are never evicted, so this grows with the
        # number of distinct job_ids broadcast during the process lifetime.
        self._latest_events: Dict[str, dict] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """Accept and register a new WebSocket connection for a job.

        If a message has already been broadcast for ``job_id``, it is sent to
        the new client wrapped in an ``initial_state`` message. A failure to
        send that initial state is logged; the connection stays registered.

        Args:
            job_id: Job whose events the client wants to receive.
            websocket: Client connection; ``accept()`` is awaited here.
        """
        await websocket.accept()

        ws_id = id(websocket)
        async with self._lock:
            job_connections = self._connections.setdefault(job_id, {})
            job_connections[ws_id] = websocket
            total_connections = len(job_connections)
            # Snapshot under the lock; the actual send happens after release.
            latest_event = self._latest_events.get(job_id)

        # Send initial state if available. This is done outside the lock so a
        # slow or dead client cannot block other connects/disconnects. There
        # is no suspension point between releasing the lock and starting this
        # send, so it still begins before any later broadcast can interleave.
        if latest_event is not None:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": latest_event,
                })
            except Exception as e:
                logger.warning("Failed to send initial state to new connection", error=str(e))

        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=total_connections)

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """Unregister a WebSocket connection.

        Safe to call for a connection or job that is not registered (for
        example, one already pruned by ``broadcast`` after a failed send).
        The job's entry is removed once its last connection is gone.
        """
        # Computed up front so the log call below never references an
        # unassigned name when the job is unknown.
        ws_id = id(websocket)
        async with self._lock:
            job_connections = self._connections.get(job_id)
            if job_connections is not None:
                job_connections.pop(ws_id, None)
                # Clean up empty job connections
                if not job_connections:
                    del self._connections[job_id]
            remaining_connections = len(self._connections.get(job_id, {}))

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining_connections)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.

        A message whose ``type`` is not ``"initial_state"`` is also cached as
        the job's latest event, even when no client is currently connected.
        Connections whose send raises are unregistered.

        Returns the number of successful broadcasts.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
            self._latest_events[job_id] = message

        job_connections = self._connections.get(job_id)
        if not job_connections:
            logger.debug("No active connections for job", job_id=job_id)
            return 0

        # Snapshot: the registry may change while each send is awaited.
        connections = list(job_connections.values())
        successful_sends = 0
        failed_websockets = []

        for websocket in connections:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)
            else:
                successful_sends += 1

        # Clean up failed connections
        if failed_websockets:
            async with self._lock:
                # Re-fetch rather than index: a concurrent disconnect() may
                # have deleted this job's entry while the sends were awaited,
                # and self._connections[job_id] would then raise KeyError.
                remaining = self._connections.get(job_id)
                if remaining is not None:
                    for ws in failed_websockets:
                        remaining.pop(id(ws), None)
                    # Mirror disconnect(): drop the job entry once it is empty.
                    if not remaining:
                        del self._connections[job_id]

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))

        return successful_sends

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job (0 if none)."""
        return len(self._connections.get(job_id, {}))


# Global singleton instance, created once at import time and shared by every
# importer of this module within the process.
# NOTE(review): WebSocketConnectionManager.__init__ creates an asyncio.Lock.
# On Python < 3.10, that lock binds to the event loop current at construction
# (here: import time), which can differ from the loop serving requests.
# Confirm the runtime Python version.
websocket_manager = WebSocketConnectionManager()