Websocket fix 2

This commit is contained in:
Urtzi Alfaro
2025-08-01 18:13:34 +02:00
parent 81e7ab7432
commit 37938b614f
2 changed files with 64 additions and 36 deletions

View File

@@ -82,31 +82,23 @@ async def training_progress_websocket(
tenant_id: str,
job_id: str
):
"""
FIXED WebSocket endpoint for real-time training progress updates
Prevents premature disconnection during training
"""
connection_id = f"{tenant_id}_{id(websocket)}"
# Accept connection
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}")
# Set up RabbitMQ consumer for this job
consumer_task = None
training_completed = False # Track training completion
training_completed = False
try:
# Start RabbitMQ consumer
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
# Give consumer time to set up
await asyncio.sleep(0.5)
# Send initial status if available
# Send initial status
try:
initial_status = await get_current_job_status(job_id, tenant_id)
if initial_status:
@@ -115,39 +107,50 @@ async def training_progress_websocket(
"job_id": job_id,
"data": initial_status
})
logger.info(f"Sent initial status for job {job_id}")
except Exception as e:
logger.warning(f"Failed to send initial status: {e}")
# Keep connection alive - IMPROVED ERROR HANDLING
last_activity = asyncio.get_event_loop().time()
while not training_completed:
try:
# Wait for client messages with timeout
# FIXED: Use receive() instead of receive_text()
try:
message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
last_activity = asyncio.get_event_loop().time()
if message == "ping":
await websocket.send_text("pong")
logger.debug(f"Ping received from job {job_id}")
elif message == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message == "close":
logger.info(f"Client requested connection close for job {job_id}")
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
message_text = data["text"]
if message_text == "ping":
await websocket.send_text("pong")
logger.debug(f"Text ping received from job {job_id}")
elif message_text == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
# Handle binary messages (WebSocket ping frames)
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 30 seconds - send heartbeat
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 60: # 60 seconds of inactivity
if current_time - last_activity > 60:
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
try:
@@ -165,7 +168,11 @@ async def training_progress_websocket(
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
# Don't break immediately - try to recover
# Check if it's the specific "cannot call receive" error
if "Cannot call" in str(e) and "disconnect message" in str(e):
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
break
# Don't break immediately for other errors - try to recover
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
@@ -174,19 +181,15 @@ async def training_progress_websocket(
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally:
# IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id)
# Only cancel consumer if we're truly done (not just a temporary error)
if consumer_task and not consumer_task.done():
if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
# Let the consumer continue running for other potential connections
# Don't cancel it unless we're sure the job is done
try:
await consumer_task

View File

@@ -314,6 +314,29 @@ req.on('upgrade', (res, socket, head) => {
}
});
function createTextFrame(text) {
const payload = Buffer.from(text, 'utf8');
const payloadLength = payload.length;
let frame;
if (payloadLength < 126) {
frame = Buffer.allocUnsafe(2 + payloadLength);
frame[0] = 0x81; // Text frame, FIN=1
frame[1] = payloadLength;
payload.copy(frame, 2);
} else if (payloadLength < 65536) {
frame = Buffer.allocUnsafe(4 + payloadLength);
frame[0] = 0x81;
frame[1] = 126;
frame.writeUInt16BE(payloadLength, 2);
payload.copy(frame, 4);
} else {
throw new Error('Payload too large');
}
return frame;
}
// Enhanced message processing function
function processTrainingMessage(message, timestamp) {
const messageType = message.type || 'unknown';
@@ -443,13 +466,15 @@ req.on('upgrade', (res, socket, head) => {
const pingInterval = setInterval(() => {
if (socket.writable && !jobCompleted) {
try {
const pingFrame = Buffer.from([0x89, 0x00]);
socket.write(pingFrame);
// Send JSON ping message instead of binary frame
const pingMessage = JSON.stringify({ type: 'ping' });
const textFrame = createTextFrame(pingMessage);
socket.write(textFrame);
} catch (e) {
// Ignore ping errors
}
}
}, 5000); // Ping every 5 seconds
}, 5000);
// Heartbeat check - ensure we're still receiving messages
const heartbeatInterval = setInterval(() => {