# services/training/app/api/websocket.py
"""
WebSocket endpoints for real-time training progress updates
"""

import asyncio
import json
import logging
from typing import Any, Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter

from app.services.messaging import training_publisher
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep
)

logger = logging.getLogger(__name__)

# Create WebSocket router
websocket_router = APIRouter()

# RabbitMQ event type -> WebSocket message "type" field.
# Built once at import time instead of on every forwarded message.
_EVENT_TYPE_TO_WEBSOCKET_TYPE: Dict[str, str] = {
    "training.started": "started",
    "training.progress": "progress",
    "training.completed": "completed",
    "training.failed": "failed",
    "training.cancelled": "cancelled",
    "training.step.completed": "step_completed",
    "training.product.started": "product_started",
    "training.product.completed": "product_completed",
    "training.product.failed": "product_failed",
    "training.model.trained": "model_trained",
    "training.data.validation.started": "validation_started",
    "training.data.validation.completed": "validation_completed",
}


class ConnectionManager:
    """Manage WebSocket connections for training progress.

    Connections are grouped by job id so a training event can be sent to
    every client currently watching that job.
    """

    def __init__(self):
        # Structure: {job_id: {connection_id: websocket}}
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept the WebSocket handshake and register the connection.

        Args:
            websocket: The incoming connection; it is accepted here.
            job_id: Training job the client wants to follow.
            connection_id: Key identifying this connection within the job.
        """
        await websocket.accept()
        self.active_connections.setdefault(job_id, {})[connection_id] = websocket
        logger.info(
            "WebSocket connected for job %s, connection %s", job_id, connection_id
        )

    def disconnect(self, job_id: str, connection_id: str):
        """Unregister a connection; unknown job or connection ids are ignored.

        The per-job entry is dropped once its last connection is removed.
        The socket itself is not closed here.
        """
        job_connections = self.active_connections.get(job_id)
        if job_connections is not None:
            job_connections.pop(connection_id, None)
            if not job_connections:
                del self.active_connections[job_id]

        logger.info(
            "WebSocket disconnected for job %s, connection %s", job_id, connection_id
        )

    async def send_to_job(self, job_id: str, message: dict):
        """Send a JSON message to every connection registered for a job.

        Connections whose send raises are logged and unregistered; one
        failing client does not prevent delivery to the others.

        Args:
            job_id: Job whose watchers should receive the message.
            message: JSON-serialisable payload passed to ``send_json``.
        """
        job_connections = self.active_connections.get(job_id)
        if not job_connections:
            return

        failed_connection_ids = []
        # Iterate over a snapshot: each ``await`` yields to the event loop,
        # where connect()/disconnect() may mutate the dict. Iterating the
        # live view would then raise "dictionary changed size during
        # iteration".
        for connection_id, websocket in list(job_connections.items()):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(
                    "Failed to send message to connection %s: %s", connection_id, e
                )
                failed_connection_ids.append(connection_id)

        # Clean up connections that could not be written to
        for connection_id in failed_connection_ids:
            self.disconnect(job_id, connection_id)


# Global connection manager
connection_manager = ConnectionManager()


@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """
    WebSocket endpoint for real-time training progress updates

    Client -> server text messages:
        "ping"        answered with the text "pong"
        "get_status"  answered with a "current_status" JSON message

    Message format sent to client for forwarded training events:
    {
        "type": "progress" | "completed" | "failed" | "step_completed" | ...,
        "job_id": "job_123",
        "timestamp": "2025-07-30T19:08:53Z",
        "data": {
            "progress": 45,
            "current_step": "Training model for bread",
            "current_product": "bread",
            "products_completed": 2,
            "products_total": 5,
            "estimated_time_remaining_minutes": 8
        }
    }

    Status snapshots use "type": "initial_status" (sent once after connect,
    when a status is available) or "current_status" (reply to "get_status")
    and carry only "job_id" and "data".
    """
    # NOTE(review): no authentication or tenant-ownership check is visible in
    # this handler, and the auth dependencies imported at the top of the file
    # are not used here -- confirm that access control is enforced elsewhere.
    connection_id = f"{tenant_id}_{id(websocket)}"

    # Accept connection
    await connection_manager.connect(websocket, job_id, connection_id)

    # Background task that forwards RabbitMQ events for this job
    consumer_task = None

    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        # Send initial status if available; failure here is not fatal
        try:
            initial_status = await get_current_job_status(job_id, tenant_id)
            if initial_status:
                await websocket.send_json({
                    "type": "initial_status",
                    "job_id": job_id,
                    "data": initial_status
                })
        except Exception as e:
            logger.warning("Failed to send initial status: %s", e)

        # Keep connection alive and handle client messages
        while True:
            try:
                # Wait for client ping or other messages
                message = await websocket.receive_text()

                if message == "ping":
                    await websocket.send_text("pong")
                elif message == "get_status":
                    # Send current status on demand
                    current_status = await get_current_job_status(job_id, tenant_id)
                    if current_status:
                        await websocket.send_json({
                            "type": "current_status",
                            "job_id": job_id,
                            "data": current_status
                        })

            except WebSocketDisconnect:
                logger.info("WebSocket disconnected for job %s", job_id)
                break
            except Exception as e:
                logger.error("WebSocket error for job %s: %s", job_id, e)
                break

    finally:
        # Clean up: unregister first so no further events are sent to this
        # socket, then stop the consumer task and wait for it to finish.
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task and not consumer_task.done():
            consumer_task.cancel()
            try:
                await consumer_task
            except asyncio.CancelledError:
                pass


async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Set up RabbitMQ consumer to listen for training events for a specific job.

    Subscribes to ``training.*`` on the ``training.events`` exchange and
    forwards every event whose ``data.job_id`` equals ``job_id`` to all
    WebSocket connections registered for that job. Exceptions raised while
    subscribing are logged and swallowed.

    Args:
        job_id: Job whose events should be forwarded.
        tenant_id: Used only to build the queue name.
    """
    try:
        # The queue name depends only on job and tenant, so it is NOT unique
        # per WebSocket connection: every connection watching the same job
        # uses the same name.
        # NOTE(review): presumably the broker then treats these as competing
        # consumers on one queue, which is what makes the fan-out through
        # connection_manager.send_to_job deliver each event once per client
        # -- confirm against training_publisher.consume_events.
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job; events for
                # other jobs are acknowledged and dropped.
                if event_data.get("job_id") != job_id:
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Acknowledge the message
                await message.ack()

                logger.debug("Forwarded training event to WebSocket: %s", event_type)

            except Exception as e:
                logger.error("Error handling training message for WebSocket: %s", e)
                # Reject without requeue so a malformed message is not
                # redelivered.
                await message.nack(requeue=False)

        # Subscribe to training events
        await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",  # Listen to all training events
            callback=handle_training_message
        )

    except Exception as e:
        logger.error("Failed to set up RabbitMQ consumer for WebSocket: %s", e)


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types.

    Returns "unknown" for event types that have no mapping.
    """
    return _EVENT_TYPE_TO_WEBSOCKET_TYPE.get(rabbitmq_event_type, "unknown")


async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
    """Get current job status from database or cache.

    Currently a placeholder: it opens a database session but runs no query
    and returns a fixed "running" status for the given ``job_id``.
    ``tenant_id`` is not used yet.

    Returns:
        A status dict, or ``None`` if anything raised (the error is logged).
    """
    try:
        # TODO: query the database for the real job status. The imports are
        # local to this function; if either fails, the error is caught below
        # and None is returned.
        from app.core.database import get_db_session
        from app.models.training import ModelTrainingLog  # Assuming you have this model

        async with get_db_session() as db:
            # Query your training job status
            # This is a placeholder - adjust based on your actual database models
            pass

        # Placeholder return - replace with actual database query
        return {
            "job_id": job_id,
            "status": "running",  # or "completed", "failed", etc.
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }

    except Exception as e:
        logger.error("Failed to get current job status: %s", e)
        return None