# services/training/app/api/websocket.py
"""
WebSocket endpoints for real-time training progress updates
"""

import asyncio
import json
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import structlog
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter

from app.services.messaging import training_publisher
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep
)

logger = structlog.get_logger(__name__)

# Create WebSocket router
websocket_router = APIRouter()

# How long (seconds) to wait for a client message before checking inactivity.
CLIENT_RECEIVE_TIMEOUT_SECONDS = 30.0
# Client inactivity (seconds) after which a heartbeat frame is pushed.
HEARTBEAT_INACTIVITY_SECONDS = 60
# Unexpected errors in a row tolerated by the connection loop before it gives
# up, so a broken socket cannot spin forever in the 1-second retry path.
MAX_CONSECUTIVE_LOOP_ERRORS = 5
# RabbitMQ event types after which no further events are expected for a job.
TERMINAL_EVENT_TYPES = frozenset({
    "training.completed",
    "training.failed",
    "training.cancelled",
})

# RabbitMQ event type -> WebSocket message "type" sent to clients.
_EVENT_TYPE_TO_WEBSOCKET_TYPE: Dict[str, str] = {
    "training.started": "started",
    "training.progress": "progress",
    "training.completed": "completed",  # terminal event
    "training.failed": "failed",  # terminal event
    "training.cancelled": "cancelled",  # terminal event
    "training.step.completed": "step_completed",
    "training.product.started": "product_started",
    "training.product.completed": "product_completed",
    "training.product.failed": "product_failed",
    "training.model.trained": "model_trained",
    "training.data.validation.started": "validation_started",
    "training.data.validation.completed": "validation_completed"
}


class ConnectionManager:
    """Manage WebSocket connections for training progress.

    Connections are grouped by job id so a single training event can be fanned
    out to every client watching that job.
    """

    def __init__(self):
        # Structure: {job_id: {connection_id: websocket}}
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str) -> None:
        """Accept the WebSocket handshake and register it under ``job_id``."""
        await websocket.accept()
        self.active_connections.setdefault(job_id, {})[connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str) -> None:
        """Unregister a connection; drop the job entry once it has no connections.

        Safe to call for unknown job ids or connection ids.
        """
        if job_id in self.active_connections:
            self.active_connections[job_id].pop(connection_id, None)
            if not self.active_connections[job_id]:
                del self.active_connections[job_id]
        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict) -> None:
        """Send ``message`` as JSON to every connection registered for ``job_id``.

        Connections whose send fails are unregistered afterwards.

        The connection map is snapshotted before sending: each ``await`` yields
        to the event loop, and other coroutines may connect or disconnect
        clients for the same job meanwhile. Iterating the live dict would then
        raise ``RuntimeError: dictionary changed size during iteration``.
        """
        if job_id not in self.active_connections:
            logger.debug(f"No active connections for job {job_id}")
            return

        message_type = message.get('type', 'unknown')
        failed_connections = []
        for connection_id, websocket in list(self.active_connections[job_id].items()):
            try:
                await websocket.send_json(message)
                logger.debug(f"📤 Sent {message_type} to connection {connection_id}")
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                failed_connections.append(connection_id)

        # Clean up disconnected connections
        for connection_id in failed_connections:
            self.disconnect(job_id, connection_id)

        # Log how many connections remain registered for the job
        active_count = len(self.active_connections.get(job_id, {}))
        if active_count > 0:
            logger.info(f"📡 Sent {message_type} message to {active_count} connection(s) for job {job_id}")


# Global connection manager
connection_manager = ConnectionManager()


async def _receive_or_finish(
    websocket: WebSocket,
    finished: asyncio.Event,
    timeout: float,
) -> Optional[Dict[str, Any]]:
    """Wait for the next client ASGI message or for ``finished`` to be set.

    Returns:
        The raw message dict from ``websocket.receive()``, or None when
        ``finished`` is set first.

    Raises:
        asyncio.TimeoutError: neither happened within ``timeout`` seconds.
        Any exception raised by ``websocket.receive()`` propagates.
    """
    if finished.is_set():
        return None

    receive_task = asyncio.create_task(websocket.receive())
    finished_task = asyncio.create_task(finished.wait())
    try:
        done, _ = await asyncio.wait(
            {receive_task, finished_task},
            timeout=timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        # Never leave a dangling receive()/wait() behind, including when this
        # coroutine itself is cancelled.
        for task in (receive_task, finished_task):
            if not task.done():
                task.cancel()

    if receive_task in done:
        # A client message that raced with completion is delivered first; the
        # caller observes ``finished`` on its next call.
        return receive_task.result()
    if finished_task in done:
        return None
    raise asyncio.TimeoutError()


async def _handle_client_message(
    websocket: WebSocket,
    data: Dict[str, Any],
    job_id: str,
    tenant_id: str,
) -> bool:
    """Act on one raw ASGI message received from the client.

    Supported text commands:
        "ping": answered with "pong".
        "get_status": answered with a ``current_status`` frame.
        "close": ends the connection loop.

    Any binary frame is answered with "pong".

    Returns:
        True if the connection loop should stop (client disconnected or asked
        to close), False otherwise.
    """
    if data["type"] == "websocket.receive":
        if "text" in data:
            message_text = data["text"]
            if message_text == "ping":
                await websocket.send_text("pong")
                logger.debug(f"Text ping received from job {job_id}")
            elif message_text == "get_status":
                current_status = await get_current_job_status(job_id, tenant_id)
                if current_status:
                    await websocket.send_json({
                        "type": "current_status",
                        "job_id": job_id,
                        "data": current_status
                    })
            elif message_text == "close":
                logger.info(f"Client requested connection close for job {job_id}")
                return True
        elif "bytes" in data:
            # Binary data frames are treated as application-level pings
            # (protocol-level ping frames are not delivered to the app).
            await websocket.send_text("pong")
            logger.debug(f"Binary ping received for job {job_id}")
    elif data["type"] == "websocket.disconnect":
        logger.info(f"WebSocket disconnect message received for job {job_id}")
        return True
    return False


@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """Stream live training events for ``job_id`` to a WebSocket client.

    Steps:
      1. Accept and register the connection with ``connection_manager``.
      2. Start a RabbitMQ consumer task that forwards this job's training events
         to every registered connection.
      3. Send an initial status snapshot.
      4. Serve client commands (see ``_handle_client_message``) and push
         heartbeats on inactivity. This continues until one of:
         - the client leaves or asks to close;
         - the consumer signals a terminal training event;
         - repeated unexpected errors occur.
      5. Always unregister the connection and cancel/await the consumer task.
    """
    connection_id = f"{tenant_id}_{id(websocket)}"
    await connection_manager.connect(websocket, job_id, connection_id)
    logger.info(f"WebSocket connection established for job {job_id}")

    consumer_task = None
    training_completed = False
    # Set by this connection's consumer when a terminal training event arrives.
    training_finished = asyncio.Event()

    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(
                job_id, tenant_id, on_training_finished=training_finished
            )
        )
        # Short pause, presumably so the consumer can subscribe before the
        # initial snapshot is sent.
        await asyncio.sleep(0.5)

        # Send initial status
        try:
            initial_status = await get_current_job_status(job_id, tenant_id)
            if initial_status:
                await websocket.send_json({
                    "type": "initial_status",
                    "job_id": job_id,
                    "data": initial_status
                })
        except Exception as e:
            logger.warning(f"Failed to send initial status: {e}")

        loop = asyncio.get_running_loop()
        last_activity = loop.time()
        consecutive_errors = 0

        while not training_completed:
            try:
                try:
                    data = await _receive_or_finish(
                        websocket, training_finished, CLIENT_RECEIVE_TIMEOUT_SECONDS
                    )
                except asyncio.TimeoutError:
                    consecutive_errors = 0
                    # No client message within the timeout: heartbeat if idle long enough.
                    if loop.time() - last_activity > HEARTBEAT_INACTIVITY_SECONDS:
                        logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
                        try:
                            await websocket.send_json({
                                "type": "heartbeat",
                                "job_id": job_id,
                                "timestamp": datetime.now(timezone.utc).isoformat()
                            })
                        except Exception as e:
                            logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                            break
                    continue

                if data is None:
                    # The consumer saw a terminal event; it was already forwarded
                    # to clients before the event was set.
                    logger.info(f"Terminal training event received for job {job_id}, ending WebSocket loop")
                    training_completed = True
                    break

                consecutive_errors = 0
                last_activity = loop.time()
                if await _handle_client_message(websocket, data, job_id, tenant_id):
                    break

            except WebSocketDisconnect:
                logger.info(f"WebSocket client disconnected for job {job_id}")
                break
            except Exception as e:
                logger.error(f"WebSocket error for job {job_id}: {e}")
                # Starlette raises this RuntimeError once the socket has already closed.
                if "Cannot call" in str(e) and "disconnect message" in str(e):
                    logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
                    break
                consecutive_errors += 1
                if consecutive_errors >= MAX_CONSECUTIVE_LOOP_ERRORS:
                    logger.error(f"Too many consecutive WebSocket errors for job {job_id}, closing loop")
                    break
                # Try to recover from transient errors
                await asyncio.sleep(1)

        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")

    except Exception as e:
        logger.error(f"Critical WebSocket error for job {job_id}: {e}")
    finally:
        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task and not consumer_task.done():
            if training_completed:
                logger.info(f"Training completed, cancelling consumer for job {job_id}")
            else:
                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
            # Always cancel: otherwise the consumer keeps running, and awaiting
            # it below would block this handler indefinitely.
            consumer_task.cancel()
            try:
                await consumer_task
            except asyncio.CancelledError:
                logger.info(f"Consumer task cancelled for job {job_id}")
            except Exception as e:
                logger.error(f"Consumer task error for job {job_id}: {e}")


async def setup_rabbitmq_consumer_for_job(
    job_id: str,
    tenant_id: str,
    on_training_finished: Optional[asyncio.Event] = None,
):
    """Set up a RabbitMQ consumer that forwards this job's training events to WebSockets.

    Subscribes to ``training.*`` on the ``training.events`` exchange. It then
    stays alive until it is cancelled or an event in ``TERMINAL_EVENT_TYPES``
    arrives for ``job_id``.

    Args:
        job_id: Only events whose ``data.job_id`` equals this are forwarded;
            other events are acknowledged and dropped.
        tenant_id: Used to build the queue name.
        on_training_finished: Optional event set when a terminal event for the
            job is received, so a waiting WebSocket handler can finish.

    Errors are logged and swallowed, except cancellation, which propagates.

    NOTE(review): cancelling this coroutine does not visibly unsubscribe the
    callback registered via ``training_publisher.consume_events``; confirm the
    publisher cleans it up.
    """
    logger.info(f"🚀 Setting up RabbitMQ consumer for job {job_id}")
    training_finished = on_training_finished if on_training_finished is not None else asyncio.Event()

    try:
        # Queue name is per job/tenant, so concurrent WebSockets for the same job
        # share it. Each consumer still fans out to all of the job's connections
        # through connection_manager.
        # NOTE(review): because of the shared queue, only the connection whose
        # consumer receives the terminal message is signalled to finish.
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Decode one RabbitMQ message, forward it if it belongs to job_id, then ack.

            Messages that fail processing are nacked without requeue.
            """
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                logger.debug(f"🔍 Received message for job {job_id}: {data.get('event_type', 'unknown')}")

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job
                message_job_id = event_data.get("job_id") if event_data else None
                if message_job_id != job_id:
                    logger.debug(f"⏭️ Ignoring message for different job: {message_job_id}")
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                logger.info(f"📤 Forwarding {event_type} message to WebSocket clients for job {job_id}")

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Signal completion only after the terminal message was forwarded.
                if event_type in TERMINAL_EVENT_TYPES:
                    logger.info(f"🎯 Training completion detected for job {job_id}: {event_type}")
                    training_finished.set()

                # Acknowledge the message
                await message.ack()

                logger.debug(f"✅ Successfully processed {event_type} for job {job_id}")

            except Exception as e:
                logger.error(f"❌ Error handling training message for job {job_id}: {e}")
                logger.error(f"💥 Traceback: {traceback.format_exc()}")
                await message.nack(requeue=False)

        # Check if training_publisher is connected
        if not training_publisher.connected:
            logger.warning(f"⚠️ Training publisher not connected for job {job_id}, attempting to connect...")
            success = await training_publisher.connect()
            if not success:
                logger.error(f"❌ Failed to connect training_publisher for job {job_id}")
                return

        # Subscribe to training events
        logger.info(f"🔗 Subscribing to training events for job {job_id}")
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",  # Listen to all training events
            callback=handle_training_message
        )

        if success:
            logger.info(f"✅ Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")

            # Stay alive until a terminal event arrives or the task is cancelled.
            try:
                while not training_finished.is_set():
                    try:
                        await asyncio.wait_for(training_finished.wait(), timeout=10)
                    except asyncio.TimeoutError:
                        logger.debug(f"🔄 Consumer heartbeat for job {job_id}")
            except asyncio.CancelledError:
                logger.info(f"🛑 Consumer cancelled for job {job_id}")
                raise
            except Exception as e:
                logger.error(f"💥 Consumer error for job {job_id}: {e}")
                raise

            logger.info(f"🏁 Consumer finished for job {job_id}: training reached a terminal state")
        else:
            logger.error(f"❌ Failed to set up RabbitMQ consumer for job {job_id}")

    except Exception as e:
        logger.error(f"💥 Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
        logger.error(f"🔥 Traceback: {traceback.format_exc()}")


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Map a RabbitMQ event type to a WebSocket message type ("unknown" if unmapped)."""
    return _EVENT_TYPE_TO_WEBSOCKET_TYPE.get(rabbitmq_event_type, "unknown")


async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
    """Return a status snapshot for ``job_id``, or None if it cannot be produced.

    NOTE(review): still a placeholder. It does not query any storage and always
    reports a "running" job with zero progress and a fixed start time.
    TODO: load the real job state for ``job_id``/``tenant_id`` from the
    database or cache.
    """
    try:
        # Placeholder return - replace with actual database query
        return {
            "job_id": job_id,
            "status": "running",  # or "completed", "failed", etc.
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }
    except Exception as e:
        # Kept for the real implementation: status lookups must never break the socket.
        logger.error(f"Failed to get current job status: {e}")
        return None