WebSocket fix 2
This commit is contained in:
@@ -82,31 +82,23 @@ async def training_progress_websocket(
|
||||
tenant_id: str,
|
||||
job_id: str
|
||||
):
|
||||
"""
|
||||
FIXED WebSocket endpoint for real-time training progress updates
|
||||
Prevents premature disconnection during training
|
||||
"""
|
||||
connection_id = f"{tenant_id}_{id(websocket)}"
|
||||
|
||||
# Accept connection
|
||||
await connection_manager.connect(websocket, job_id, connection_id)
|
||||
logger.info(f"WebSocket connection established for job {job_id}")
|
||||
|
||||
# Set up RabbitMQ consumer for this job
|
||||
consumer_task = None
|
||||
training_completed = False # Track training completion
|
||||
training_completed = False
|
||||
|
||||
try:
|
||||
# Start RabbitMQ consumer
|
||||
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
|
||||
consumer_task = asyncio.create_task(
|
||||
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
|
||||
)
|
||||
|
||||
# Give consumer time to set up
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send initial status if available
|
||||
# Send initial status
|
||||
try:
|
||||
initial_status = await get_current_job_status(job_id, tenant_id)
|
||||
if initial_status:
|
||||
@@ -115,39 +107,50 @@ async def training_progress_websocket(
|
||||
"job_id": job_id,
|
||||
"data": initial_status
|
||||
})
|
||||
logger.info(f"Sent initial status for job {job_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send initial status: {e}")
|
||||
|
||||
# Keep connection alive - IMPROVED ERROR HANDLING
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
while not training_completed:
|
||||
try:
|
||||
# Wait for client messages with timeout
|
||||
# FIXED: Use receive() instead of receive_text()
|
||||
try:
|
||||
message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
if message == "ping":
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Ping received from job {job_id}")
|
||||
elif message == "get_status":
|
||||
current_status = await get_current_job_status(job_id, tenant_id)
|
||||
if current_status:
|
||||
await websocket.send_json({
|
||||
"type": "current_status",
|
||||
"job_id": job_id,
|
||||
"data": current_status
|
||||
})
|
||||
elif message == "close":
|
||||
logger.info(f"Client requested connection close for job {job_id}")
|
||||
# Handle different message types
|
||||
if data["type"] == "websocket.receive":
|
||||
if "text" in data:
|
||||
message_text = data["text"]
|
||||
if message_text == "ping":
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Text ping received from job {job_id}")
|
||||
elif message_text == "get_status":
|
||||
current_status = await get_current_job_status(job_id, tenant_id)
|
||||
if current_status:
|
||||
await websocket.send_json({
|
||||
"type": "current_status",
|
||||
"job_id": job_id,
|
||||
"data": current_status
|
||||
})
|
||||
elif message_text == "close":
|
||||
logger.info(f"Client requested connection close for job {job_id}")
|
||||
break
|
||||
|
||||
elif "bytes" in data:
|
||||
# Handle binary messages (any binary payload is treated as a client ping and answered with "pong")
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Binary ping received for job {job_id}")
|
||||
|
||||
elif data["type"] == "websocket.disconnect":
|
||||
logger.info(f"WebSocket disconnect message received for job {job_id}")
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No message received in 30 seconds - send heartbeat
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - last_activity > 60: # 60 seconds of inactivity
|
||||
if current_time - last_activity > 60:
|
||||
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
|
||||
|
||||
try:
|
||||
@@ -165,7 +168,11 @@ async def training_progress_websocket(
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for job {job_id}: {e}")
|
||||
# Don't break immediately - try to recover
|
||||
# Check if it's the specific "cannot call receive" error
|
||||
if "Cannot call" in str(e) and "disconnect message" in str(e):
|
||||
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
|
||||
break
|
||||
# Don't break immediately for other errors - try to recover
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
|
||||
@@ -174,19 +181,15 @@ async def training_progress_websocket(
|
||||
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
|
||||
|
||||
finally:
|
||||
# IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting
|
||||
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
|
||||
connection_manager.disconnect(job_id, connection_id)
|
||||
|
||||
# Only cancel consumer if we're truly done (not just a temporary error)
|
||||
if consumer_task and not consumer_task.done():
|
||||
if training_completed:
|
||||
logger.info(f"Training completed, cancelling consumer for job {job_id}")
|
||||
consumer_task.cancel()
|
||||
else:
|
||||
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
|
||||
# Let the consumer continue running for other potential connections
|
||||
# Don't cancel it unless we're sure the job is done
|
||||
|
||||
try:
|
||||
await consumer_task
|
||||
|
||||
Reference in New Issue
Block a user