REFACTOR external service and improve websocket training

This commit is contained in:
Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -0,0 +1,11 @@
"""WebSocket support for training service"""
from app.websocket.manager import websocket_manager, WebSocketConnectionManager
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
__all__ = [
'websocket_manager',
'WebSocketConnectionManager',
'setup_websocket_event_consumer',
'cleanup_websocket_consumers'
]

View File

@@ -0,0 +1,148 @@
"""
RabbitMQ Event Consumer for WebSocket Broadcasting
Listens to training events from RabbitMQ and broadcasts them to WebSocket clients
"""
import asyncio
import json
from typing import Dict, Set
import structlog
from app.websocket.manager import websocket_manager
from app.services.training_events import training_publisher
logger = structlog.get_logger()
# Track active consumers
_active_consumers: Set[asyncio.Task] = set()
async def handle_training_event(message) -> None:
"""
Handle incoming RabbitMQ training events and broadcast to WebSocket clients.
This is the bridge between RabbitMQ and WebSocket.
"""
try:
# Parse message
body = message.body.decode()
data = json.loads(body)
event_type = data.get('event_type', 'unknown')
event_data = data.get('data', {})
job_id = event_data.get('job_id')
if not job_id:
logger.warning("Received event without job_id, skipping", event_type=event_type)
await message.ack()
return
logger.info("Received training event from RabbitMQ",
job_id=job_id,
event_type=event_type,
progress=event_data.get('progress'))
# Map RabbitMQ event types to WebSocket message types
ws_message_type = _map_event_type(event_type)
# Create WebSocket message
ws_message = {
"type": ws_message_type,
"job_id": job_id,
"timestamp": data.get('timestamp'),
"data": event_data
}
# Broadcast to all WebSocket clients for this job
sent_count = await websocket_manager.broadcast(job_id, ws_message)
logger.info("Broadcasted event to WebSocket clients",
job_id=job_id,
event_type=event_type,
ws_message_type=ws_message_type,
clients_notified=sent_count)
# Always acknowledge the message to avoid infinite redelivery loops
# Progress events (started, progress, product_completed) are ephemeral and don't need redelivery
# Final events (completed, failed) should always be acknowledged
await message.ack()
except Exception as e:
logger.error("Error handling training event",
error=str(e),
exc_info=True)
# Always acknowledge even on error to avoid infinite redelivery loops
# The event is logged so we can debug issues
try:
await message.ack()
except:
pass # Message already gone or connection closed
def _map_event_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.step.completed": "step_completed",
"training.product.completed": "product_completed",
"training.completed": "completed",
"training.failed": "failed",
}
return mapping.get(rabbitmq_event_type, "unknown")
async def setup_websocket_event_consumer() -> bool:
"""
Set up a global RabbitMQ consumer that listens to all training events
and broadcasts them to connected WebSocket clients.
"""
try:
# Ensure publisher is connected
if not training_publisher.connected:
logger.info("Connecting training publisher for WebSocket event consumer")
success = await training_publisher.connect()
if not success:
logger.error("Failed to connect training publisher")
return False
# Create a unique queue for WebSocket broadcasting
queue_name = "training_websocket_broadcast"
logger.info("Setting up WebSocket event consumer", queue_name=queue_name)
# Subscribe to all training events (routing key: training.#)
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.#", # Listen to all training events (multi-level)
callback=handle_training_event
)
if success:
logger.info("WebSocket event consumer set up successfully")
return True
else:
logger.error("Failed to set up WebSocket event consumer")
return False
except Exception as e:
logger.error("Error setting up WebSocket event consumer",
error=str(e),
exc_info=True)
return False
async def cleanup_websocket_consumers() -> None:
"""Clean up WebSocket event consumers"""
logger.info("Cleaning up WebSocket event consumers")
for task in _active_consumers:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
_active_consumers.clear()
logger.info("WebSocket event consumers cleaned up")

View File

@@ -0,0 +1,120 @@
"""
WebSocket Connection Manager for Training Service
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
"""
import asyncio
import json
from typing import Dict, Set
from fastapi import WebSocket
import structlog
logger = structlog.get_logger()
class WebSocketConnectionManager:
"""
Simple WebSocket connection manager.
Manages connections per job_id and broadcasts messages to all connected clients.
"""
def __init__(self):
# Structure: {job_id: {websocket_id: WebSocket}}
self._connections: Dict[str, Dict[int, WebSocket]] = {}
self._lock = asyncio.Lock()
# Store latest event for each job to provide initial state
self._latest_events: Dict[str, dict] = {}
async def connect(self, job_id: str, websocket: WebSocket) -> None:
"""Register a new WebSocket connection for a job"""
await websocket.accept()
async with self._lock:
if job_id not in self._connections:
self._connections[job_id] = {}
ws_id = id(websocket)
self._connections[job_id][ws_id] = websocket
# Send initial state if available
if job_id in self._latest_events:
try:
await websocket.send_json({
"type": "initial_state",
"job_id": job_id,
"data": self._latest_events[job_id]
})
except Exception as e:
logger.warning("Failed to send initial state to new connection", error=str(e))
logger.info("WebSocket connected",
job_id=job_id,
websocket_id=ws_id,
total_connections=len(self._connections[job_id]))
async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
"""Remove a WebSocket connection"""
async with self._lock:
if job_id in self._connections:
ws_id = id(websocket)
self._connections[job_id].pop(ws_id, None)
# Clean up empty job connections
if not self._connections[job_id]:
del self._connections[job_id]
logger.info("WebSocket disconnected",
job_id=job_id,
websocket_id=ws_id,
remaining_connections=len(self._connections.get(job_id, {})))
async def broadcast(self, job_id: str, message: dict) -> int:
"""
Broadcast a message to all connections for a specific job.
Returns the number of successful broadcasts.
"""
# Store the latest event for this job to provide initial state to new connections
if message.get('type') != 'initial_state': # Don't store initial_state messages
self._latest_events[job_id] = message
if job_id not in self._connections:
logger.debug("No active connections for job", job_id=job_id)
return 0
connections = list(self._connections[job_id].values())
successful_sends = 0
failed_websockets = []
for websocket in connections:
try:
await websocket.send_json(message)
successful_sends += 1
except Exception as e:
logger.warning("Failed to send message to WebSocket",
job_id=job_id,
error=str(e))
failed_websockets.append(websocket)
# Clean up failed connections
if failed_websockets:
async with self._lock:
for ws in failed_websockets:
ws_id = id(ws)
self._connections[job_id].pop(ws_id, None)
if successful_sends > 0:
logger.info("Broadcasted message to WebSocket clients",
job_id=job_id,
message_type=message.get('type'),
successful_sends=successful_sends,
failed_sends=len(failed_websockets))
return successful_sends
def get_connection_count(self, job_id: str) -> int:
"""Get the number of active connections for a job"""
return len(self._connections.get(job_id, {}))
# Global singleton instance
websocket_manager = WebSocketConnectionManager()