"""
WebSocket Operations for Training Service

Simple WebSocket endpoint that registers clients with the WebSocket manager,
through which they receive training-progress broadcasts relayed from RabbitMQ.
"""
|
||
|
|
|
||
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
||
|
|
import structlog
|
||
|
|
|
||
|
|
from app.websocket.manager import websocket_manager
|
||
|
|
from shared.auth.jwt_handler import JWTHandler
|
||
|
|
from app.core.config import settings
|
||
|
|
|
||
|
|
# Router carrying the training-progress WebSocket route declared in this module.
router = APIRouter(tags=["websocket"])

# Module-wide structlog logger used by the WebSocket handler.
logger = structlog.get_logger()
|
||
|
|
|
||
|
|
|
||
|
|
async def _reject_connection(
    websocket: WebSocket,
    reason: str,
    log_message: str,
    **log_fields,
) -> None:
    """Log a rejected connection attempt, then close it with code 1008.

    1008 is the RFC 6455 "policy violation" close code. The warning is logged
    *before* closing so the rejection is recorded even if the close fails, and
    a failing close (e.g. the peer already went away) is logged at debug level
    instead of escaping the handler.

    Args:
        websocket: The not-yet-authenticated connection to reject.
        reason: Close reason sent with the close frame.
        log_message: Warning message describing why the connection was rejected.
        **log_fields: Extra structured fields attached to both log records.
    """
    logger.warning(log_message, **log_fields)
    try:
        await websocket.close(code=1008, reason=reason)
    except Exception as close_error:
        # Best-effort: nothing more can be done for a peer we cannot close.
        # Uses a distinct key so it cannot collide with an 'error' log field.
        logger.debug("Failed to close rejected WebSocket",
                     close_error=str(close_error),
                     **log_fields)


@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    token: str = Query(..., description="Authentication token")
):
    """
    WebSocket endpoint for real-time training progress updates.

    This endpoint:
    1. Validates the authentication token passed in the ``token`` query parameter
    2. Registers the socket with ``websocket_manager`` (which presumably accepts
       the handshake -- this function never calls ``accept()`` itself)
    3. Keeps the connection alive, answering "ping" text frames with "pong"
    4. Receives broadcasts from RabbitMQ (via WebSocket manager)

    Connections with a missing/invalid token, or a token payload without a
    ``user_id``, are closed with code 1008 before being registered.
    """
    log_fields = {"job_id": job_id, "tenant_id": tenant_id}
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    # Guard only token decoding/inspection. The close and log calls are kept out
    # of this try so their own failures are not misreported as auth failures
    # (and cannot trigger a second close() from the except branch).
    try:
        payload = jwt_handler.verify_token(token)
        user_id = payload.get('user_id') if payload else None
    except Exception as e:
        await _reject_connection(websocket,
                                 "Authentication failed",
                                 "WebSocket authentication failed",
                                 error=str(e),
                                 **log_fields)
        return

    if not payload:
        await _reject_connection(websocket,
                                 "Invalid token",
                                 "WebSocket connection rejected - invalid token",
                                 **log_fields)
        return

    if not user_id:
        await _reject_connection(websocket,
                                 "Invalid token payload",
                                 "WebSocket connection rejected - no user_id in token",
                                 **log_fields)
        return

    # NOTE(review): the path tenant_id is never compared against the token, so
    # any authenticated user can subscribe to any tenant's job stream. Confirm
    # whether the JWT payload carries a tenant claim and enforce it here.
    logger.info("WebSocket authentication successful",
                user_id=user_id,
                **log_fields)

    # Connect to WebSocket manager; from here on the finally-block guarantees
    # the socket is deregistered however the session ends.
    await websocket_manager.connect(job_id, websocket)

    try:
        # Send connection confirmation
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Connected to training progress stream"
        })

        # Keep the connection alive; progress updates are pushed by the manager,
        # so the only client traffic handled here is the "ping" keep-alive.
        ping_count = 0
        while True:
            try:
                data = await websocket.receive_text()

                if data == "ping":
                    await websocket.send_text("pong")
                    ping_count += 1
                    logger.info("WebSocket ping/pong",
                                job_id=job_id,
                                ping_count=ping_count,
                                connection_healthy=True)

            except WebSocketDisconnect:
                logger.info("Client disconnected", job_id=job_id)
                break
            except Exception as e:
                logger.error("Error in WebSocket message loop",
                             job_id=job_id,
                             error=str(e))
                break

    finally:
        # Disconnect from manager
        await websocket_manager.disconnect(job_id, websocket)
        logger.info("WebSocket connection closed",
                    job_id=job_id,
                    tenant_id=tenant_id)
|