"""
RabbitMQ Event Consumer for WebSocket Broadcasting

Listens to training events from RabbitMQ and broadcasts them to WebSocket clients.
"""

import asyncio
import json
from typing import Set

import structlog

from app.websocket.manager import websocket_manager
from app.services.training_events import training_publisher

logger = structlog.get_logger()

# Track active consumers
_active_consumers: Set[asyncio.Task] = set()
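# How this set gets populated is not shown in this module. A typical pattern (assumed,
# purely illustrative) is to register any long-running consumer task created elsewhere
# so that cleanup_websocket_consumers() below can cancel it on shutdown:
#
#   task = asyncio.create_task(consume_forever())  # consume_forever() is hypothetical
#   _active_consumers.add(task)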


async def handle_training_event(message) -> None:
    """
    Handle incoming RabbitMQ training events and broadcast to WebSocket clients.

    This is the bridge between RabbitMQ and WebSocket.
    """
    try:
        # Parse message
        body = message.body.decode()
        data = json.loads(body)

        event_type = data.get('event_type', 'unknown')
        event_data = data.get('data', {})
        job_id = event_data.get('job_id')

        if not job_id:
            logger.warning("Received event without job_id, skipping", event_type=event_type)
            await message.ack()
            return

        logger.info("Received training event from RabbitMQ",
                    job_id=job_id,
                    event_type=event_type,
                    progress=event_data.get('progress'))

        # Map RabbitMQ event types to WebSocket message types
        ws_message_type = _map_event_type(event_type)

        # Create WebSocket message
        ws_message = {
            "type": ws_message_type,
            "job_id": job_id,
            "timestamp": data.get('timestamp'),
            "data": event_data
        }

        # Broadcast to all WebSocket clients for this job
        sent_count = await websocket_manager.broadcast(job_id, ws_message)

        logger.info("Broadcasted event to WebSocket clients",
                    job_id=job_id,
                    event_type=event_type,
                    ws_message_type=ws_message_type,
                    clients_notified=sent_count)

        # Always acknowledge the message to avoid infinite redelivery loops.
        # Progress events (started, progress, product_completed) are ephemeral and don't need redelivery;
        # final events (completed, failed) should always be acknowledged.
        await message.ack()
    except Exception as e:
        logger.error("Error handling training event",
                     error=str(e),
                     exc_info=True)
        # Always acknowledge even on error to avoid infinite redelivery loops;
        # the event is logged so we can debug issues.
        try:
            await message.ack()
        except Exception:
            pass  # Message already gone or connection closed
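
# For reference: the handler above assumes an event envelope roughly shaped like the
# following (field names inferred from the parsing code; concrete values are illustrative
# only, and the authoritative payload is whatever training_publisher actually emits):
#
#   {
#       "event_type": "training.progress",
#       "timestamp": "2025-01-01T12:00:00Z",
#       "data": {"job_id": "job-123", "progress": 0.42}
#   }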


def _map_event_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types"""
    mapping = {
        "training.started": "started",
        "training.progress": "progress",
        "training.step.completed": "step_completed",
        "training.product.completed": "product_completed",
        "training.completed": "completed",
        "training.failed": "failed",
    }
    return mapping.get(rabbitmq_event_type, "unknown")
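
# For example, _map_event_type("training.progress") returns "progress"; any event type
# not listed in the mapping falls back to "unknown".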


async def setup_websocket_event_consumer() -> bool:
    """
    Set up a global RabbitMQ consumer that listens to all training events
    and broadcasts them to connected WebSocket clients.
    """
    try:
        # Ensure publisher is connected
        if not training_publisher.connected:
            logger.info("Connecting training publisher for WebSocket event consumer")
            success = await training_publisher.connect()
            if not success:
                logger.error("Failed to connect training publisher")
                return False

        # Create a dedicated queue for WebSocket broadcasting
        queue_name = "training_websocket_broadcast"

        logger.info("Setting up WebSocket event consumer", queue_name=queue_name)

        # Subscribe to all training events (routing key: training.#)
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.#",  # listen to all training events (multi-level wildcard)
            callback=handle_training_event
        )

        if success:
            logger.info("WebSocket event consumer set up successfully")
            return True
        else:
            logger.error("Failed to set up WebSocket event consumer")
            return False

    except Exception as e:
        logger.error("Error setting up WebSocket event consumer",
                     error=str(e),
                     exc_info=True)
        return False


async def cleanup_websocket_consumers() -> None:
    """Clean up WebSocket event consumers"""
    logger.info("Cleaning up WebSocket event consumers")

    for task in _active_consumers:
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    _active_consumers.clear()
    logger.info("WebSocket event consumers cleaned up")
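
# Illustrative wiring (an assumption, not part of this module): an application
# startup/shutdown hook, e.g. a FastAPI lifespan handler, would typically call the two
# entry points above along these lines:
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app):
#       if not await setup_websocket_event_consumer():
#           logger.warning("WebSocket event consumer unavailable; live updates disabled")
#       yield
#       await cleanup_websocket_consumers()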