Add base kubernetes support final fix 4
This commit is contained in:
@@ -82,11 +82,45 @@ async def training_progress_websocket(
|
||||
tenant_id: str,
|
||||
job_id: str
|
||||
):
|
||||
connection_id = f"{tenant_id}_{id(websocket)}"
|
||||
|
||||
# Validate token from query parameters
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Authentication token required")
|
||||
return
|
||||
|
||||
# Validate the token (use the same JWT handler as gateway)
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid authentication token")
|
||||
return
|
||||
|
||||
# Verify user has access to this tenant
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
return
|
||||
|
||||
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
|
||||
await websocket.close(code=1008, reason="Token validation failed")
|
||||
return
|
||||
|
||||
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
|
||||
|
||||
await connection_manager.connect(websocket, job_id, connection_id)
|
||||
logger.info(f"WebSocket connection established for job {job_id}")
|
||||
|
||||
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
|
||||
|
||||
consumer_task = None
|
||||
training_completed = False
|
||||
|
||||
@@ -100,11 +134,12 @@ async def training_progress_websocket(
|
||||
|
||||
while not training_completed:
|
||||
try:
|
||||
# FIXED: Use receive() instead of receive_text()
|
||||
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
|
||||
# This should be longer than gateway timeout to avoid premature closure
|
||||
try:
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
|
||||
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
|
||||
last_activity = asyncio.get_event_loop().time()
|
||||
|
||||
|
||||
# Handle different message types
|
||||
if data["type"] == "websocket.receive":
|
||||
if "text" in data:
|
||||
@@ -123,31 +158,41 @@ async def training_progress_websocket(
|
||||
elif message_text == "close":
|
||||
logger.info(f"Client requested connection close for job {job_id}")
|
||||
break
|
||||
|
||||
|
||||
elif "bytes" in data:
|
||||
# Handle binary messages (WebSocket ping frames)
|
||||
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
|
||||
await websocket.send_text("pong")
|
||||
logger.debug(f"Binary ping received for job {job_id}")
|
||||
|
||||
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
|
||||
|
||||
elif data["type"] == "websocket.disconnect":
|
||||
logger.info(f"WebSocket disconnect message received for job {job_id}")
|
||||
break
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No message received in 30 seconds - send heartbeat
|
||||
# No message received in 60 seconds - this is now coordinated with gateway timeouts
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
if current_time - last_activity > 60:
|
||||
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "heartbeat",
|
||||
"job_id": job_id,
|
||||
"timestamp": str(datetime.datetime.now())
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
|
||||
break
|
||||
|
||||
# Send heartbeat only if we haven't received frontend ping for too long
|
||||
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
|
||||
if current_time - last_activity > 90: # 90 seconds of total inactivity
|
||||
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
"type": "heartbeat",
|
||||
"job_id": job_id,
|
||||
"timestamp": str(datetime.datetime.now()),
|
||||
"message": "Training service heartbeat - frontend inactive",
|
||||
"inactivity_seconds": int(current_time - last_activity)
|
||||
})
|
||||
last_activity = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
|
||||
break
|
||||
else:
|
||||
# Normal timeout, frontend should be sending ping every 30s
|
||||
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
|
||||
continue
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client disconnected for job {job_id}")
|
||||
|
||||
Reference in New Issue
Block a user