diff --git a/services/training/app/api/websocket.py b/services/training/app/api/websocket.py index 60184651..ad72c58b 100644 --- a/services/training/app/api/websocket.py +++ b/services/training/app/api/websocket.py @@ -82,31 +82,23 @@ async def training_progress_websocket( tenant_id: str, job_id: str ): - """ - FIXED WebSocket endpoint for real-time training progress updates - Prevents premature disconnection during training - """ connection_id = f"{tenant_id}_{id(websocket)}" - # Accept connection await connection_manager.connect(websocket, job_id, connection_id) logger.info(f"WebSocket connection established for job {job_id}") - # Set up RabbitMQ consumer for this job consumer_task = None - training_completed = False # Track training completion + training_completed = False try: # Start RabbitMQ consumer - logger.info(f"Setting up RabbitMQ consumer for job {job_id}") consumer_task = asyncio.create_task( setup_rabbitmq_consumer_for_job(job_id, tenant_id) ) - # Give consumer time to set up await asyncio.sleep(0.5) - # Send initial status if available + # Send initial status try: initial_status = await get_current_job_status(job_id, tenant_id) if initial_status: @@ -115,39 +107,50 @@ async def training_progress_websocket( "job_id": job_id, "data": initial_status }) - logger.info(f"Sent initial status for job {job_id}") except Exception as e: logger.warning(f"Failed to send initial status: {e}") - # Keep connection alive - IMPROVED ERROR HANDLING last_activity = asyncio.get_event_loop().time() while not training_completed: try: - # Wait for client messages with timeout + # FIXED: Use receive() instead of receive_text() try: - message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0) + data = await asyncio.wait_for(websocket.receive(), timeout=30.0) last_activity = asyncio.get_event_loop().time() - if message == "ping": - await websocket.send_text("pong") - logger.debug(f"Ping received from job {job_id}") - elif message == "get_status": - current_status = await get_current_job_status(job_id, tenant_id) - if current_status: - await websocket.send_json({ - "type": "current_status", - "job_id": job_id, - "data": current_status - }) - elif message == "close": - logger.info(f"Client requested connection close for job {job_id}") + # Handle different message types + if data["type"] == "websocket.receive": + if "text" in data: + message_text = data["text"] + if message_text == "ping": + await websocket.send_text("pong") + logger.debug(f"Text ping received from job {job_id}") + elif message_text == "get_status": + current_status = await get_current_job_status(job_id, tenant_id) + if current_status: + await websocket.send_json({ + "type": "current_status", + "job_id": job_id, + "data": current_status + }) + elif message_text == "close": + logger.info(f"Client requested connection close for job {job_id}") + break + + elif "bytes" in data: + # Handle binary messages (WebSocket ping frames) + await websocket.send_text("pong") + logger.debug(f"Binary ping received for job {job_id}") + + elif data["type"] == "websocket.disconnect": + logger.info(f"WebSocket disconnect message received for job {job_id}") break except asyncio.TimeoutError: # No message received in 30 seconds - send heartbeat current_time = asyncio.get_event_loop().time() - if current_time - last_activity > 60: # 60 seconds of inactivity + if current_time - last_activity > 60: logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat") try: @@ -165,7 +168,11 @@ async def training_progress_websocket( break except Exception as e: logger.error(f"WebSocket error for job {job_id}: {e}") - # Don't break immediately - try to recover + # Check if it's the specific "cannot call receive" error + if "Cannot call" in str(e) and "disconnect message" in str(e): + logger.error(f"FastAPI WebSocket disconnect error - connection already closed") + break + # Don't break immediately for other errors - try to recover await asyncio.sleep(1) logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}") @@ -174,19 +181,15 @@ async def training_progress_websocket( logger.error(f"Critical WebSocket error for job {job_id}: {e}") finally: - # IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting logger.info(f"Cleaning up WebSocket connection for job {job_id}") connection_manager.disconnect(job_id, connection_id) - # Only cancel consumer if we're truly done (not just a temporary error) if consumer_task and not consumer_task.done(): if training_completed: logger.info(f"Training completed, cancelling consumer for job {job_id}") consumer_task.cancel() else: logger.warning(f"WebSocket disconnected but training not completed for job {job_id}") - # Let the consumer continue running for other potential connections - # Don't cancel it unless we're sure the job is done try: await consumer_task diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh index bf16c94e..4c61f6c4 100755 --- a/tests/test_onboarding_flow.sh +++ b/tests/test_onboarding_flow.sh @@ -313,6 +313,29 @@ req.on('upgrade', (res, socket, head) => { } } }); + + function createTextFrame(text) { + const payload = Buffer.from(text, 'utf8'); + const payloadLength = payload.length; + + let frame; + if (payloadLength < 126) { + frame = Buffer.allocUnsafe(2 + payloadLength); + frame[0] = 0x81; // Text frame, FIN=1 + frame[1] = payloadLength; + payload.copy(frame, 2); + } else if (payloadLength < 65536) { + frame = Buffer.allocUnsafe(4 + payloadLength); + frame[0] = 0x81; + frame[1] = 126; + frame.writeUInt16BE(payloadLength, 2); + payload.copy(frame, 4); + } else { + throw new Error('Payload too large'); + } + + return frame; + } // Enhanced message processing function function processTrainingMessage(message, timestamp) { @@ -443,13 +466,15 @@ req.on('upgrade', (res, socket, head) => { const pingInterval = setInterval(() => { if (socket.writable && !jobCompleted) { try { - const pingFrame = Buffer.from([0x89, 0x00]); - socket.write(pingFrame); + // Send JSON ping message instead of binary frame + const pingMessage = JSON.stringify({ type: 'ping' }); + const textFrame = createTextFrame(pingMessage); + socket.write(textFrame); } catch (e) { // Ignore ping errors } } - }, 5000); // Ping every 5 seconds + }, 5000); // Heartbeat check - ensure we're still receiving messages const heartbeatInterval = setInterval(() => {