""" RabbitMQ Event Consumer for WebSocket Broadcasting Listens to training events from RabbitMQ and broadcasts them to WebSocket clients """ import asyncio import json from typing import Dict, Set import structlog from app.websocket.manager import websocket_manager from app.services.training_events import training_publisher logger = structlog.get_logger() # Track active consumers _active_consumers: Set[asyncio.Task] = set() async def handle_training_event(message) -> None: """ Handle incoming RabbitMQ training events and broadcast to WebSocket clients. This is the bridge between RabbitMQ and WebSocket. """ try: # Parse message body = message.body.decode() data = json.loads(body) event_type = data.get('event_type', 'unknown') event_data = data.get('data', {}) job_id = event_data.get('job_id') if not job_id: logger.warning("Received event without job_id, skipping", event_type=event_type) await message.ack() return logger.info("Received training event from RabbitMQ", job_id=job_id, event_type=event_type, progress=event_data.get('progress')) # Map RabbitMQ event types to WebSocket message types ws_message_type = _map_event_type(event_type) # Create WebSocket message ws_message = { "type": ws_message_type, "job_id": job_id, "timestamp": data.get('timestamp'), "data": event_data } # Broadcast to all WebSocket clients for this job sent_count = await websocket_manager.broadcast(job_id, ws_message) logger.info("Broadcasted event to WebSocket clients", job_id=job_id, event_type=event_type, ws_message_type=ws_message_type, clients_notified=sent_count) # Always acknowledge the message to avoid infinite redelivery loops # Progress events (started, progress, product_completed) are ephemeral and don't need redelivery # Final events (completed, failed) should always be acknowledged await message.ack() except Exception as e: logger.error("Error handling training event", error=str(e), exc_info=True) # Always acknowledge even on error to avoid infinite redelivery loops # The event is logged so we can debug issues try: await message.ack() except: pass # Message already gone or connection closed def _map_event_type(rabbitmq_event_type: str) -> str: """Map RabbitMQ event types to WebSocket message types""" mapping = { "training.started": "started", "training.progress": "progress", "training.step.completed": "step_completed", "training.product.completed": "product_completed", "training.completed": "completed", "training.failed": "failed", } return mapping.get(rabbitmq_event_type, "unknown") async def setup_websocket_event_consumer() -> bool: """ Set up a global RabbitMQ consumer that listens to all training events and broadcasts them to connected WebSocket clients. """ try: # Ensure publisher is connected if not training_publisher.connected: logger.info("Connecting training publisher for WebSocket event consumer") success = await training_publisher.connect() if not success: logger.error("Failed to connect training publisher") return False # Create a unique queue for WebSocket broadcasting queue_name = "training_websocket_broadcast" logger.info("Setting up WebSocket event consumer", queue_name=queue_name) # Subscribe to all training events (routing key: training.#) success = await training_publisher.consume_events( exchange_name="training.events", queue_name=queue_name, routing_key="training.#", # Listen to all training events (multi-level) callback=handle_training_event ) if success: logger.info("WebSocket event consumer set up successfully") return True else: logger.error("Failed to set up WebSocket event consumer") return False except Exception as e: logger.error("Error setting up WebSocket event consumer", error=str(e), exc_info=True) return False async def cleanup_websocket_consumers() -> None: """Clean up WebSocket event consumers""" logger.info("Cleaning up WebSocket event consumers") for task in _active_consumers: if not task.done(): task.cancel() try: await task except asyncio.CancelledError: pass _active_consumers.clear() logger.info("WebSocket event consumers cleaned up")