Add base Kubernetes support final fix 4

This commit is contained in:
Urtzi Alfaro
2025-09-29 07:54:25 +02:00
parent 57f77638cc
commit 4777e59e7a
14 changed files with 1041 additions and 167 deletions

View File

@@ -82,11 +82,45 @@ async def training_progress_websocket(
tenant_id: str,
job_id: str
):
connection_id = f"{tenant_id}_{id(websocket)}"
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token (use the same JWT handler as gateway)
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}")
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
consumer_task = None
training_completed = False
@@ -100,11 +134,12 @@ async def training_progress_websocket(
while not training_completed:
try:
# FIXED: Use receive() instead of receive_text()
# Coordinate with frontend 30s heartbeat + gateway 45s timeout
# This should be longer than gateway timeout to avoid premature closure
try:
data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
@@ -123,31 +158,41 @@ async def training_progress_websocket(
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
# Handle binary messages (WebSocket ping frames)
# Handle binary messages (WebSocket ping frames) - respond with text pong for compatibility
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 30 seconds - send heartbeat
# No message received in 60 seconds - this is now coordinated with gateway timeouts
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 60:
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now())
})
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
# Send heartbeat only if we haven't received frontend ping for too long
# Frontend sends ping every 30s, so 60s timeout + 30s grace = 90s before heartbeat
if current_time - last_activity > 90: # 90 seconds of total inactivity
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
# Normal timeout, frontend should be sending ping every 30s
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")