""" WebSocket Operations for Training Service Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ """ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query import structlog from app.websocket.manager import websocket_manager from shared.auth.jwt_handler import JWTHandler from app.core.config import settings from app.services.training_service import EnhancedTrainingService from shared.database.base import create_database_manager logger = structlog.get_logger() router = APIRouter(tags=["websocket"]) def get_enhanced_training_service(): """Create EnhancedTrainingService instance""" database_manager = create_database_manager(settings.DATABASE_URL, "training-service") return EnhancedTrainingService(database_manager) @router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live") async def training_progress_websocket( websocket: WebSocket, tenant_id: str = Path(..., description="Tenant ID"), job_id: str = Path(..., description="Job ID"), token: str = Query(..., description="Authentication token") ): """ WebSocket endpoint for real-time training progress updates. This endpoint: 1. Validates the authentication token 2. Accepts the WebSocket connection 3. Keeps the connection alive 4. Receives broadcasts from RabbitMQ (via WebSocket manager) """ # Validate token jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM) try: payload = jwt_handler.verify_token(token) if not payload: await websocket.close(code=1008, reason="Invalid token") logger.warning("WebSocket connection rejected - invalid token", job_id=job_id, tenant_id=tenant_id) return user_id = payload.get('user_id') if not user_id: await websocket.close(code=1008, reason="Invalid token payload") logger.warning("WebSocket connection rejected - no user_id in token", job_id=job_id, tenant_id=tenant_id) return logger.info("WebSocket authentication successful", user_id=user_id, tenant_id=tenant_id, job_id=job_id) except Exception as e: await websocket.close(code=1008, reason="Authentication failed") logger.warning("WebSocket authentication failed", job_id=job_id, tenant_id=tenant_id, error=str(e)) return # Connect to WebSocket manager await websocket_manager.connect(job_id, websocket) # Helper function to send current job status async def send_current_status(): """Fetch and send the current job status to the client""" try: training_service = get_enhanced_training_service() status_info = await training_service.get_training_status(job_id) if status_info and not status_info.get("error"): # Map status to WebSocket message type ws_type = "progress" if status_info.get("status") == "completed": ws_type = "completed" elif status_info.get("status") == "failed": ws_type = "failed" await websocket.send_json({ "type": ws_type, "job_id": job_id, "data": { "progress": status_info.get("progress", 0), "current_step": status_info.get("current_step"), "status": status_info.get("status"), "products_total": status_info.get("products_total", 0), "products_completed": status_info.get("products_completed", 0), "products_failed": status_info.get("products_failed", 0), "estimated_time_remaining_seconds": status_info.get("estimated_time_remaining_seconds"), "message": status_info.get("message") } }) logger.info("Sent current job status to client", job_id=job_id, status=status_info.get("status"), progress=status_info.get("progress")) except Exception as e: logger.error("Failed to send current job status", job_id=job_id, error=str(e)) try: # Send connection confirmation await websocket.send_json({ "type": "connected", "job_id": job_id, "message": "Connected to training progress stream" }) # Immediately send current job status after connection # This handles the race condition where training completes before WebSocket connects await send_current_status() # Keep connection alive and handle client messages ping_count = 0 while True: try: # Receive messages from client (ping, get_status, etc.) data = await websocket.receive_text() # Handle ping/pong if data == "ping": await websocket.send_text("pong") ping_count += 1 logger.debug("WebSocket ping/pong", job_id=job_id, ping_count=ping_count, connection_healthy=True) # Handle get_status request elif data == "get_status": await send_current_status() logger.info("Status requested by client", job_id=job_id) except WebSocketDisconnect: logger.info("Client disconnected", job_id=job_id) break except Exception as e: logger.error("Error in WebSocket message loop", job_id=job_id, error=str(e)) break finally: # Disconnect from manager await websocket_manager.disconnect(job_id, websocket) logger.info("WebSocket connection closed", job_id=job_id, tenant_id=tenant_id)