Websocket fix 2
This commit is contained in:
@@ -82,31 +82,23 @@ async def training_progress_websocket(
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
job_id: str
|
job_id: str
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
FIXED WebSocket endpoint for real-time training progress updates
|
|
||||||
Prevents premature disconnection during training
|
|
||||||
"""
|
|
||||||
connection_id = f"{tenant_id}_{id(websocket)}"
|
connection_id = f"{tenant_id}_{id(websocket)}"
|
||||||
|
|
||||||
# Accept connection
|
|
||||||
await connection_manager.connect(websocket, job_id, connection_id)
|
await connection_manager.connect(websocket, job_id, connection_id)
|
||||||
logger.info(f"WebSocket connection established for job {job_id}")
|
logger.info(f"WebSocket connection established for job {job_id}")
|
||||||
|
|
||||||
# Set up RabbitMQ consumer for this job
|
|
||||||
consumer_task = None
|
consumer_task = None
|
||||||
training_completed = False # Track training completion
|
training_completed = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Start RabbitMQ consumer
|
# Start RabbitMQ consumer
|
||||||
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
|
|
||||||
consumer_task = asyncio.create_task(
|
consumer_task = asyncio.create_task(
|
||||||
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
|
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Give consumer time to set up
|
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
# Send initial status if available
|
# Send initial status
|
||||||
try:
|
try:
|
||||||
initial_status = await get_current_job_status(job_id, tenant_id)
|
initial_status = await get_current_job_status(job_id, tenant_id)
|
||||||
if initial_status:
|
if initial_status:
|
||||||
@@ -115,39 +107,50 @@ async def training_progress_websocket(
|
|||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
"data": initial_status
|
"data": initial_status
|
||||||
})
|
})
|
||||||
logger.info(f"Sent initial status for job {job_id}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to send initial status: {e}")
|
logger.warning(f"Failed to send initial status: {e}")
|
||||||
|
|
||||||
# Keep connection alive - IMPROVED ERROR HANDLING
|
|
||||||
last_activity = asyncio.get_event_loop().time()
|
last_activity = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
while not training_completed:
|
while not training_completed:
|
||||||
try:
|
try:
|
||||||
# Wait for client messages with timeout
|
# FIXED: Use receive() instead of receive_text()
|
||||||
try:
|
try:
|
||||||
message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
|
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
|
||||||
last_activity = asyncio.get_event_loop().time()
|
last_activity = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
if message == "ping":
|
# Handle different message types
|
||||||
await websocket.send_text("pong")
|
if data["type"] == "websocket.receive":
|
||||||
logger.debug(f"Ping received from job {job_id}")
|
if "text" in data:
|
||||||
elif message == "get_status":
|
message_text = data["text"]
|
||||||
current_status = await get_current_job_status(job_id, tenant_id)
|
if message_text == "ping":
|
||||||
if current_status:
|
await websocket.send_text("pong")
|
||||||
await websocket.send_json({
|
logger.debug(f"Text ping received from job {job_id}")
|
||||||
"type": "current_status",
|
elif message_text == "get_status":
|
||||||
"job_id": job_id,
|
current_status = await get_current_job_status(job_id, tenant_id)
|
||||||
"data": current_status
|
if current_status:
|
||||||
})
|
await websocket.send_json({
|
||||||
elif message == "close":
|
"type": "current_status",
|
||||||
logger.info(f"Client requested connection close for job {job_id}")
|
"job_id": job_id,
|
||||||
|
"data": current_status
|
||||||
|
})
|
||||||
|
elif message_text == "close":
|
||||||
|
logger.info(f"Client requested connection close for job {job_id}")
|
||||||
|
break
|
||||||
|
|
||||||
|
elif "bytes" in data:
|
||||||
|
# Handle binary messages (WebSocket ping frames)
|
||||||
|
await websocket.send_text("pong")
|
||||||
|
logger.debug(f"Binary ping received for job {job_id}")
|
||||||
|
|
||||||
|
elif data["type"] == "websocket.disconnect":
|
||||||
|
logger.info(f"WebSocket disconnect message received for job {job_id}")
|
||||||
break
|
break
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# No message received in 30 seconds - send heartbeat
|
# No message received in 30 seconds - send heartbeat
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = asyncio.get_event_loop().time()
|
||||||
if current_time - last_activity > 60: # 60 seconds of inactivity
|
if current_time - last_activity > 60:
|
||||||
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
|
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -165,7 +168,11 @@ async def training_progress_websocket(
|
|||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"WebSocket error for job {job_id}: {e}")
|
logger.error(f"WebSocket error for job {job_id}: {e}")
|
||||||
# Don't break immediately - try to recover
|
# Check if it's the specific "cannot call receive" error
|
||||||
|
if "Cannot call" in str(e) and "disconnect message" in str(e):
|
||||||
|
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
|
||||||
|
break
|
||||||
|
# Don't break immediately for other errors - try to recover
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
|
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
|
||||||
@@ -174,19 +181,15 @@ async def training_progress_websocket(
|
|||||||
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
|
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting
|
|
||||||
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
|
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
|
||||||
connection_manager.disconnect(job_id, connection_id)
|
connection_manager.disconnect(job_id, connection_id)
|
||||||
|
|
||||||
# Only cancel consumer if we're truly done (not just a temporary error)
|
|
||||||
if consumer_task and not consumer_task.done():
|
if consumer_task and not consumer_task.done():
|
||||||
if training_completed:
|
if training_completed:
|
||||||
logger.info(f"Training completed, cancelling consumer for job {job_id}")
|
logger.info(f"Training completed, cancelling consumer for job {job_id}")
|
||||||
consumer_task.cancel()
|
consumer_task.cancel()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
|
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
|
||||||
# Let the consumer continue running for other potential connections
|
|
||||||
# Don't cancel it unless we're sure the job is done
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await consumer_task
|
await consumer_task
|
||||||
|
|||||||
@@ -314,6 +314,29 @@ req.on('upgrade', (res, socket, head) => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
 * Build a single, unfragmented WebSocket text frame (FIN=1, opcode 0x1).
 *
 * The surrounding code registers `req.on('upgrade', (res, socket, head) => ...)`,
 * which is the http.ClientRequest form of the event, so frames built here are
 * presumably sent client-to-server. RFC 6455 section 5.1 requires every
 * client-to-server frame to be masked, and a compliant server may close the
 * connection when it receives an unmasked one. Masking is therefore on by
 * default; pass `{ masked: false }` to get the bare (server-style) layout.
 *
 * All three payload-length encodings are supported: 7-bit (< 126 bytes),
 * 16-bit (< 65536 bytes) and 64-bit (everything larger).
 *
 * @param {string} text - Message to send; encoded as UTF-8.
 * @param {{ masked?: boolean }} [options] - `masked` (default `true`) adds a
 *   random 4-byte masking key and XOR-masks the payload.
 * @returns {Buffer} The complete frame, ready for `socket.write()`.
 */
function createTextFrame(text, { masked = true } = {}) {
    const payload = Buffer.from(text, 'utf8');
    const payloadLength = payload.length;

    // The header grows with the payload: 2 bytes, 2 + 2, or 2 + 8.
    let headerLength;
    if (payloadLength < 126) {
        headerLength = 2;
    } else if (payloadLength < 65536) {
        headerLength = 4;
    } else {
        headerLength = 10;
    }

    const maskLength = masked ? 4 : 0;
    const payloadOffset = headerLength + maskLength;
    // Zero-filled allocation: no stale memory can leak into the frame.
    const frame = Buffer.alloc(payloadOffset + payloadLength);
    const maskBit = masked ? 0x80 : 0x00;

    frame[0] = 0x81; // Text frame, FIN=1
    if (headerLength === 2) {
        frame[1] = maskBit | payloadLength;
    } else if (headerLength === 4) {
        frame[1] = maskBit | 126;
        frame.writeUInt16BE(payloadLength, 2);
    } else {
        frame[1] = maskBit | 127;
        // 64-bit big-endian length written as two 32-bit halves.
        frame.writeUInt32BE(Math.floor(payloadLength / 0x100000000), 2);
        frame.writeUInt32BE(payloadLength % 0x100000000, 6);
    }

    if (!masked) {
        payload.copy(frame, payloadOffset);
        return frame;
    }

    // The masking key sits between the header and the payload.
    const maskingKey = frame.subarray(headerLength, payloadOffset);
    const webCrypto = globalThis.crypto;
    if (webCrypto && typeof webCrypto.getRandomValues === 'function') {
        webCrypto.getRandomValues(maskingKey);
    } else {
        // Fallback for runtimes without a global Web Crypto object. Weaker
        // randomness, but still yields a frame the server will accept.
        for (let i = 0; i < maskingKey.length; i += 1) {
            maskingKey[i] = Math.floor(Math.random() * 256);
        }
    }

    for (let i = 0; i < payloadLength; i += 1) {
        frame[payloadOffset + i] = payload[i] ^ maskingKey[i % 4];
    }

    return frame;
}
|
||||||
|
|
||||||
// Enhanced message processing function
|
// Enhanced message processing function
|
||||||
function processTrainingMessage(message, timestamp) {
|
function processTrainingMessage(message, timestamp) {
|
||||||
const messageType = message.type || 'unknown';
|
const messageType = message.type || 'unknown';
|
||||||
@@ -443,13 +466,15 @@ req.on('upgrade', (res, socket, head) => {
|
|||||||
const pingInterval = setInterval(() => {
|
const pingInterval = setInterval(() => {
|
||||||
if (socket.writable && !jobCompleted) {
|
if (socket.writable && !jobCompleted) {
|
||||||
try {
|
try {
|
||||||
const pingFrame = Buffer.from([0x89, 0x00]);
|
// Send JSON ping message instead of binary frame
|
||||||
socket.write(pingFrame);
|
const pingMessage = JSON.stringify({ type: 'ping' });
|
||||||
|
const textFrame = createTextFrame(pingMessage);
|
||||||
|
socket.write(textFrame);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
// Ignore ping errors
|
// Ignore ping errors
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, 5000); // Ping every 5 seconds
|
}, 5000);
|
||||||
|
|
||||||
// Heartbeat check - ensure we're still receiving messages
|
// Heartbeat check - ensure we're still receiving messages
|
||||||
const heartbeatInterval = setInterval(() => {
|
const heartbeatInterval = setInterval(() => {
|
||||||
|
|||||||
Reference in New Issue
Block a user