REFACTOR external service and improve websocket training

This commit is contained in:
Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -255,28 +255,59 @@ async def events_stream(request: Request, tenant_id: str):
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""
WebSocket proxy that forwards connections directly to training service.
Acts as a pure proxy - does NOT handle websocket logic, just forwards to training service.
All auth, message handling, and business logic is in the training service.
Simple WebSocket proxy with token verification only.
Validates the token and forwards the connection to the training service.
"""
# Get token from query params (required for training service authentication)
# Get token from query params
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket proxy rejected - missing token for job {job_id}")
logger.warning("WebSocket proxy rejected - missing token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return
# Accept the connection immediately
# Verify token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload or not payload.get('user_id'):
logger.warning("WebSocket proxy rejected - invalid token",
job_id=job_id,
tenant_id=tenant_id)
await websocket.accept()
await websocket.close(code=1008, reason="Invalid token")
return
logger.info("WebSocket proxy - token verified",
user_id=payload.get('user_id'),
tenant_id=tenant_id,
job_id=job_id)
except Exception as e:
logger.warning("WebSocket proxy - token verification failed",
job_id=job_id,
error=str(e))
await websocket.accept()
await websocket.close(code=1008, reason="Token verification failed")
return
# Accept the connection
await websocket.accept()
logger.info(f"Gateway proxying WebSocket to training service for job {job_id}, tenant {tenant_id}")
# Build WebSocket URL to training service - forward to the exact same path
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
logger.info("Gateway proxying WebSocket to training service",
job_id=job_id,
training_ws_url=training_ws_url.replace(token, '***'))
training_ws = None
try:
@@ -285,17 +316,15 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
training_ws = await websockets.connect(
training_ws_url,
ping_interval=None, # Let training service handle heartbeat
ping_timeout=None,
close_timeout=10,
open_timeout=30, # Allow time for training service to setup
max_size=2**20,
max_queue=32
ping_interval=120, # Send ping every 2 minutes (tolerates long training operations)
ping_timeout=60, # Wait up to 1 minute for pong (graceful timeout)
close_timeout=60, # Increase close timeout for graceful shutdown
open_timeout=30
)
logger.info(f"Gateway connected to training service WebSocket for job {job_id}")
logger.info("Gateway connected to training service WebSocket", job_id=job_id)
async def forward_to_training():
async def forward_frontend_to_training():
"""Forward messages from frontend to training service"""
try:
while training_ws and training_ws.open:
@@ -304,55 +333,58 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
if data.get("type") == "websocket.receive":
if "text" in data:
await training_ws.send(data["text"])
logger.debug(f"Gateway forwarded frontend->training: {data['text'][:100]}")
elif "bytes" in data:
await training_ws.send(data["bytes"])
elif data.get("type") == "websocket.disconnect":
logger.info(f"Frontend disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"Error forwarding frontend->training for job {job_id}: {e}")
logger.debug("Frontend to training forward ended", error=str(e))
async def forward_to_frontend():
async def forward_training_to_frontend():
"""Forward messages from training service to frontend"""
message_count = 0
try:
while training_ws and training_ws.open:
message = await training_ws.recv()
await websocket.send_text(message)
logger.debug(f"Gateway forwarded training->frontend: {message[:100]}")
message_count += 1
# Log every 10th message to track connectivity
if message_count % 10 == 0:
logger.debug("WebSocket proxy active",
job_id=job_id,
messages_forwarded=message_count)
except Exception as e:
logger.error(f"Error forwarding training->frontend for job {job_id}: {e}")
logger.info("Training to frontend forward ended",
job_id=job_id,
messages_forwarded=message_count,
error=str(e))
# Run both forwarding tasks concurrently
await asyncio.gather(
forward_to_training(),
forward_to_frontend(),
forward_frontend_to_training(),
forward_training_to_frontend(),
return_exceptions=True
)
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"Training service WebSocket closed for job {job_id}: {e}")
except websockets.exceptions.WebSocketException as e:
logger.error(f"WebSocket exception for job {job_id}: {e}")
except Exception as e:
logger.error(f"WebSocket proxy error for job {job_id}: {e}")
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
finally:
# Cleanup
if training_ws and not training_ws.closed:
try:
await training_ws.close()
logger.info(f"Closed training service WebSocket for job {job_id}")
except Exception as e:
logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")
except:
pass
try:
if not websocket.client_state.name == 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy closed")
except Exception as e:
logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")
except:
pass
logger.info(f"Gateway WebSocket proxy cleanup completed for job {job_id}")
logger.info("WebSocket proxy connection closed", job_id=job_id)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)