REFACTOR external service and improve websocket training
This commit is contained in:
120
services/training/app/websocket/manager.py
Normal file
120
services/training/app/websocket/manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
WebSocket Connection Manager for Training Service
|
||||
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Set
|
||||
from fastapi import WebSocket
|
||||
import structlog
|
||||
|
||||
# Module-level structlog logger; the connection manager below logs through it
# using an event string plus keyword arguments (job_id, websocket_id, ...).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class WebSocketConnectionManager:
    """
    Simple WebSocket connection manager.

    Tracks open WebSocket connections per ``job_id`` and broadcasts messages
    to every client connected for that job. The most recent broadcast for each
    job is cached so that a client connecting later immediately receives it as
    an ``initial_state`` message.

    NOTE: entries in ``_latest_events`` are never evicted by this class, so the
    cache grows with the number of distinct job ids that were broadcast to.
    """

    def __init__(self):
        # Structure: {job_id: {websocket_id: WebSocket}}, where websocket_id
        # is id(websocket).
        self._connections: Dict[str, Dict[int, WebSocket]] = {}
        # Serialises structural changes to _connections (add/remove/prune).
        self._lock = asyncio.Lock()
        # Store latest event for each job to provide initial state
        self._latest_events: Dict[str, dict] = {}

    async def connect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Accept a WebSocket and register it for ``job_id``.

        If an event was already broadcast for this job, it is sent to the new
        client wrapped in an ``initial_state`` message. A failure to send that
        message is logged and otherwise ignored; the connection stays
        registered.
        """
        await websocket.accept()
        ws_id = id(websocket)

        async with self._lock:
            self._connections.setdefault(job_id, {})[ws_id] = websocket

        # Send initial state if available. Done outside the lock so that a
        # slow client cannot stall other connects/disconnects.
        latest_event = self._latest_events.get(job_id)
        if latest_event is not None:
            try:
                await websocket.send_json({
                    "type": "initial_state",
                    "job_id": job_id,
                    "data": latest_event,
                })
            except Exception as e:
                logger.warning("Failed to send initial state to new connection", error=str(e))

        # The job entry may already be gone if the client disconnected while
        # the initial state was being sent, so never index it directly here.
        logger.info("WebSocket connected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    total_connections=self.get_connection_count(job_id))

    async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
        """
        Remove a WebSocket connection.

        Unknown job ids and already-removed sockets are ignored. When the last
        connection of a job is removed, the job entry itself is deleted.
        """
        ws_id = id(websocket)

        async with self._lock:
            job_connections = self._connections.get(job_id)
            if job_connections is None:
                return

            job_connections.pop(ws_id, None)

            # Clean up empty job connections
            if not job_connections:
                del self._connections[job_id]

            remaining_connections = len(self._connections.get(job_id, {}))

        logger.info("WebSocket disconnected",
                    job_id=job_id,
                    websocket_id=ws_id,
                    remaining_connections=remaining_connections)

    async def broadcast(self, job_id: str, message: dict) -> int:
        """
        Broadcast a message to all connections for a specific job.

        The message is also cached as the job's latest event (unless its
        ``type`` is ``initial_state``) so later connections can be primed.
        Connections whose send raises are logged and unregistered.

        Returns the number of successful broadcasts.
        """
        # Store the latest event for this job to provide initial state to new connections
        if message.get('type') != 'initial_state':  # Don't store initial_state messages
            self._latest_events[job_id] = message

        job_connections = self._connections.get(job_id)
        if job_connections is None:
            logger.debug("No active connections for job", job_id=job_id)
            return 0

        successful_sends = 0
        failed_websockets = []

        # Iterate over a snapshot: connect()/disconnect() may mutate the dict
        # while this coroutine is suspended inside send_json().
        for websocket in list(job_connections.values()):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning("Failed to send message to WebSocket",
                               job_id=job_id,
                               error=str(e))
                failed_websockets.append(websocket)
            else:
                successful_sends += 1

        # Clean up failed connections
        if failed_websockets:
            await self._drop_connections(job_id, failed_websockets)

        if successful_sends > 0:
            logger.info("Broadcasted message to WebSocket clients",
                        job_id=job_id,
                        message_type=message.get('type'),
                        successful_sends=successful_sends,
                        failed_sends=len(failed_websockets))

        return successful_sends

    async def _drop_connections(self, job_id: str, websockets: list) -> None:
        """Unregister the given sockets for ``job_id`` and prune the job entry if it is empty."""
        async with self._lock:
            # The job entry may have been deleted by a concurrent disconnect()
            # while broadcast() was awaiting sends; look it up defensively.
            job_connections = self._connections.get(job_id)
            if job_connections is None:
                return

            for ws in websockets:
                job_connections.pop(id(ws), None)

            if not job_connections:
                del self._connections[job_id]

    def get_connection_count(self, job_id: str) -> int:
        """Get the number of active connections for a job (0 for unknown jobs)."""
        return len(self._connections.get(job_id, {}))
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
# Created once when this module is first imported; code that imports
# `websocket_manager` from here shares this single manager object.
websocket_manager = WebSocketConnectionManager()
|
||||
Reference in New Issue
Block a user