"""
WebSocket Operations for Training Service

Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
"""

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
import structlog

from app.websocket.manager import websocket_manager
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings

logger = structlog.get_logger()

router = APIRouter(tags=["websocket"])


async def _reject_connection(
    websocket: WebSocket, close_reason: str, log_event: str, **log_fields
) -> None:
    """Log why a connection is refused, then close it with code 1008.

    Args:
        websocket: The socket being refused.
        close_reason: Reason string sent with the close frame.
        log_event: Event name for the warning log entry.
        **log_fields: Extra structured fields attached to the log entries.

    The warning is emitted before closing, so it is recorded even if the
    close itself fails (e.g. the client already went away). A failing close
    is logged at debug level and not re-raised. The connection is being
    refused anyway, so there is nothing further to clean up.
    """
    logger.warning(log_event, **log_fields)
    try:
        await websocket.close(code=1008, reason=close_reason)
    except Exception as close_error:
        # Use a distinct key: log_fields may already carry an ``error`` field.
        logger.debug(
            "WebSocket close after rejection failed",
            close_error=str(close_error),
            **log_fields,
        )


@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    token: str = Query(..., description="Authentication token"),
) -> None:
    """
    WebSocket endpoint for real-time training progress updates.

    This endpoint:
    1. Validates the authentication token (the ``token`` query parameter).
    2. Registers the socket with the WebSocket manager. The manager
       presumably accepts the connection; no accept call is made here.
    3. Sends a ``{"type": "connected", ...}`` confirmation.
    4. Keeps the connection alive, answering text ``"ping"`` with ``"pong"``
       and ignoring any other client text.
    5. Receives broadcasts from RabbitMQ (via the WebSocket manager).

    A connection whose token fails verification, yields an empty payload,
    or lacks a ``user_id`` claim is closed with code 1008 (policy
    violation) and never registered. Once registered, the socket is always
    unregistered from the manager, however the session ends.
    """
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    # Only verification and claim extraction are guarded. Each rejection's
    # close/log runs outside the try, so a failure while closing cannot
    # trigger a second close attempt from the except branch.
    try:
        payload = jwt_handler.verify_token(token)
        user_id = payload.get("user_id") if payload else None
    except Exception as e:
        await _reject_connection(
            websocket,
            "Authentication failed",
            "WebSocket authentication failed",
            job_id=job_id,
            tenant_id=tenant_id,
            error=str(e),
        )
        return

    if not payload:
        await _reject_connection(
            websocket,
            "Invalid token",
            "WebSocket connection rejected - invalid token",
            job_id=job_id,
            tenant_id=tenant_id,
        )
        return

    if not user_id:
        await _reject_connection(
            websocket,
            "Invalid token payload",
            "WebSocket connection rejected - no user_id in token",
            job_id=job_id,
            tenant_id=tenant_id,
        )
        return

    # NOTE(review): as far as this function shows, the URL's tenant_id is
    # never compared with the token's claims. Any valid token could
    # therefore subscribe to any tenant's job stream. Confirm whether
    # tenant isolation is enforced elsewhere.
    logger.info(
        "WebSocket authentication successful",
        user_id=user_id,
        tenant_id=tenant_id,
        job_id=job_id,
    )

    # Connect to WebSocket manager
    await websocket_manager.connect(job_id, websocket)
    try:
        # Sent inside the try: a client that vanishes right after connecting
        # is then handled like any other disconnect, rather than escaping to
        # the ASGI server as an unhandled exception.
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Connected to training progress stream",
        })

        # Keep the connection alive and handle client messages (ping, etc.).
        ping_count = 0
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
                ping_count += 1
                logger.info(
                    "WebSocket ping/pong",
                    job_id=job_id,
                    ping_count=ping_count,
                    connection_healthy=True,
                )
    except WebSocketDisconnect:
        logger.info("Client disconnected", job_id=job_id)
    except Exception as e:
        logger.error(
            "Error in WebSocket message loop",
            job_id=job_id,
            error=str(e),
            exc_info=True,
        )
    finally:
        # Always unregister, whether the client left or an error occurred.
        await websocket_manager.disconnect(job_id, websocket)
        logger.info("WebSocket connection closed", job_id=job_id, tenant_id=tenant_id)