REFACTOR external service and improve websocket training
This commit is contained in:
11
services/training/app/websocket/__init__.py
Normal file
11
services/training/app/websocket/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""WebSocket support for training service"""
|
||||
|
||||
from app.websocket.manager import websocket_manager, WebSocketConnectionManager
|
||||
from app.websocket.events import setup_websocket_event_consumer, cleanup_websocket_consumers
|
||||
|
||||
__all__ = [
|
||||
'websocket_manager',
|
||||
'WebSocketConnectionManager',
|
||||
'setup_websocket_event_consumer',
|
||||
'cleanup_websocket_consumers'
|
||||
]
|
||||
148
services/training/app/websocket/events.py
Normal file
148
services/training/app/websocket/events.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
RabbitMQ Event Consumer for WebSocket Broadcasting
|
||||
Listens to training events from RabbitMQ and broadcasts them to WebSocket clients
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Set
|
||||
import structlog
|
||||
|
||||
from app.websocket.manager import websocket_manager
|
||||
from app.services.training_events import training_publisher
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Track active consumers
|
||||
_active_consumers: Set[asyncio.Task] = set()
|
||||
|
||||
|
||||
async def handle_training_event(message) -> None:
|
||||
"""
|
||||
Handle incoming RabbitMQ training events and broadcast to WebSocket clients.
|
||||
This is the bridge between RabbitMQ and WebSocket.
|
||||
"""
|
||||
try:
|
||||
# Parse message
|
||||
body = message.body.decode()
|
||||
data = json.loads(body)
|
||||
|
||||
event_type = data.get('event_type', 'unknown')
|
||||
event_data = data.get('data', {})
|
||||
job_id = event_data.get('job_id')
|
||||
|
||||
if not job_id:
|
||||
logger.warning("Received event without job_id, skipping", event_type=event_type)
|
||||
await message.ack()
|
||||
return
|
||||
|
||||
logger.info("Received training event from RabbitMQ",
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
progress=event_data.get('progress'))
|
||||
|
||||
# Map RabbitMQ event types to WebSocket message types
|
||||
ws_message_type = _map_event_type(event_type)
|
||||
|
||||
# Create WebSocket message
|
||||
ws_message = {
|
||||
"type": ws_message_type,
|
||||
"job_id": job_id,
|
||||
"timestamp": data.get('timestamp'),
|
||||
"data": event_data
|
||||
}
|
||||
|
||||
# Broadcast to all WebSocket clients for this job
|
||||
sent_count = await websocket_manager.broadcast(job_id, ws_message)
|
||||
|
||||
logger.info("Broadcasted event to WebSocket clients",
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
ws_message_type=ws_message_type,
|
||||
clients_notified=sent_count)
|
||||
|
||||
# Always acknowledge the message to avoid infinite redelivery loops
|
||||
# Progress events (started, progress, product_completed) are ephemeral and don't need redelivery
|
||||
# Final events (completed, failed) should always be acknowledged
|
||||
await message.ack()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling training event",
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
# Always acknowledge even on error to avoid infinite redelivery loops
|
||||
# The event is logged so we can debug issues
|
||||
try:
|
||||
await message.ack()
|
||||
except:
|
||||
pass # Message already gone or connection closed
|
||||
|
||||
|
||||
def _map_event_type(rabbitmq_event_type: str) -> str:
|
||||
"""Map RabbitMQ event types to WebSocket message types"""
|
||||
mapping = {
|
||||
"training.started": "started",
|
||||
"training.progress": "progress",
|
||||
"training.step.completed": "step_completed",
|
||||
"training.product.completed": "product_completed",
|
||||
"training.completed": "completed",
|
||||
"training.failed": "failed",
|
||||
}
|
||||
return mapping.get(rabbitmq_event_type, "unknown")
|
||||
|
||||
|
||||
async def setup_websocket_event_consumer() -> bool:
|
||||
"""
|
||||
Set up a global RabbitMQ consumer that listens to all training events
|
||||
and broadcasts them to connected WebSocket clients.
|
||||
"""
|
||||
try:
|
||||
# Ensure publisher is connected
|
||||
if not training_publisher.connected:
|
||||
logger.info("Connecting training publisher for WebSocket event consumer")
|
||||
success = await training_publisher.connect()
|
||||
if not success:
|
||||
logger.error("Failed to connect training publisher")
|
||||
return False
|
||||
|
||||
# Create a unique queue for WebSocket broadcasting
|
||||
queue_name = "training_websocket_broadcast"
|
||||
|
||||
logger.info("Setting up WebSocket event consumer", queue_name=queue_name)
|
||||
|
||||
# Subscribe to all training events (routing key: training.#)
|
||||
success = await training_publisher.consume_events(
|
||||
exchange_name="training.events",
|
||||
queue_name=queue_name,
|
||||
routing_key="training.#", # Listen to all training events (multi-level)
|
||||
callback=handle_training_event
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("WebSocket event consumer set up successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to set up WebSocket event consumer")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error setting up WebSocket event consumer",
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
async def cleanup_websocket_consumers() -> None:
|
||||
"""Clean up WebSocket event consumers"""
|
||||
logger.info("Cleaning up WebSocket event consumers")
|
||||
|
||||
for task in _active_consumers:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
_active_consumers.clear()
|
||||
logger.info("WebSocket event consumers cleaned up")
|
||||
120
services/training/app/websocket/manager.py
Normal file
120
services/training/app/websocket/manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
WebSocket Connection Manager for Training Service
|
||||
Manages WebSocket connections and broadcasts RabbitMQ events to connected clients
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Set
|
||||
from fastapi import WebSocket
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class WebSocketConnectionManager:
|
||||
"""
|
||||
Simple WebSocket connection manager.
|
||||
Manages connections per job_id and broadcasts messages to all connected clients.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Structure: {job_id: {websocket_id: WebSocket}}
|
||||
self._connections: Dict[str, Dict[int, WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
# Store latest event for each job to provide initial state
|
||||
self._latest_events: Dict[str, dict] = {}
|
||||
|
||||
async def connect(self, job_id: str, websocket: WebSocket) -> None:
|
||||
"""Register a new WebSocket connection for a job"""
|
||||
await websocket.accept()
|
||||
|
||||
async with self._lock:
|
||||
if job_id not in self._connections:
|
||||
self._connections[job_id] = {}
|
||||
|
||||
ws_id = id(websocket)
|
||||
self._connections[job_id][ws_id] = websocket
|
||||
|
||||
# Send initial state if available
|
||||
if job_id in self._latest_events:
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "initial_state",
|
||||
"job_id": job_id,
|
||||
"data": self._latest_events[job_id]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send initial state to new connection", error=str(e))
|
||||
|
||||
logger.info("WebSocket connected",
|
||||
job_id=job_id,
|
||||
websocket_id=ws_id,
|
||||
total_connections=len(self._connections[job_id]))
|
||||
|
||||
async def disconnect(self, job_id: str, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection"""
|
||||
async with self._lock:
|
||||
if job_id in self._connections:
|
||||
ws_id = id(websocket)
|
||||
self._connections[job_id].pop(ws_id, None)
|
||||
|
||||
# Clean up empty job connections
|
||||
if not self._connections[job_id]:
|
||||
del self._connections[job_id]
|
||||
|
||||
logger.info("WebSocket disconnected",
|
||||
job_id=job_id,
|
||||
websocket_id=ws_id,
|
||||
remaining_connections=len(self._connections.get(job_id, {})))
|
||||
|
||||
async def broadcast(self, job_id: str, message: dict) -> int:
|
||||
"""
|
||||
Broadcast a message to all connections for a specific job.
|
||||
Returns the number of successful broadcasts.
|
||||
"""
|
||||
# Store the latest event for this job to provide initial state to new connections
|
||||
if message.get('type') != 'initial_state': # Don't store initial_state messages
|
||||
self._latest_events[job_id] = message
|
||||
|
||||
if job_id not in self._connections:
|
||||
logger.debug("No active connections for job", job_id=job_id)
|
||||
return 0
|
||||
|
||||
connections = list(self._connections[job_id].values())
|
||||
successful_sends = 0
|
||||
failed_websockets = []
|
||||
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
successful_sends += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send message to WebSocket",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
failed_websockets.append(websocket)
|
||||
|
||||
# Clean up failed connections
|
||||
if failed_websockets:
|
||||
async with self._lock:
|
||||
for ws in failed_websockets:
|
||||
ws_id = id(ws)
|
||||
self._connections[job_id].pop(ws_id, None)
|
||||
|
||||
if successful_sends > 0:
|
||||
logger.info("Broadcasted message to WebSocket clients",
|
||||
job_id=job_id,
|
||||
message_type=message.get('type'),
|
||||
successful_sends=successful_sends,
|
||||
failed_sends=len(failed_websockets))
|
||||
|
||||
return successful_sends
|
||||
|
||||
def get_connection_count(self, job_id: str) -> int:
|
||||
"""Get the number of active connections for a job"""
|
||||
return len(self._connections.get(job_id, {}))
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
websocket_manager = WebSocketConnectionManager()
|
||||
Reference in New Issue
Block a user