REFACTOR external service and improve websocket training
This commit is contained in:
109
services/training/app/api/websocket_operations.py
Normal file
109
services/training/app/api/websocket_operations.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
WebSocket Operations for Training Service
|
||||
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
||||
import structlog
|
||||
|
||||
from app.websocket.manager import websocket_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
|
||||
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
||||
async def training_progress_websocket(
|
||||
websocket: WebSocket,
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
job_id: str = Path(..., description="Job ID"),
|
||||
token: str = Query(..., description="Authentication token")
|
||||
):
|
||||
"""
|
||||
WebSocket endpoint for real-time training progress updates.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the authentication token
|
||||
2. Accepts the WebSocket connection
|
||||
3. Keeps the connection alive
|
||||
4. Receives broadcasts from RabbitMQ (via WebSocket manager)
|
||||
"""
|
||||
|
||||
# Validate token
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload:
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
logger.warning("WebSocket connection rejected - invalid token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
user_id = payload.get('user_id')
|
||||
if not user_id:
|
||||
await websocket.close(code=1008, reason="Invalid token payload")
|
||||
logger.warning("WebSocket connection rejected - no user_id in token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
logger.info("WebSocket authentication successful",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id)
|
||||
|
||||
except Exception as e:
|
||||
await websocket.close(code=1008, reason="Authentication failed")
|
||||
logger.warning("WebSocket authentication failed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return
|
||||
|
||||
# Connect to WebSocket manager
|
||||
await websocket_manager.connect(job_id, websocket)
|
||||
|
||||
try:
|
||||
# Send connection confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connected",
|
||||
"job_id": job_id,
|
||||
"message": "Connected to training progress stream"
|
||||
})
|
||||
|
||||
# Keep connection alive and handle client messages
|
||||
ping_count = 0
|
||||
while True:
|
||||
try:
|
||||
# Receive messages from client (ping, etc.)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Handle ping/pong
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
ping_count += 1
|
||||
logger.info("WebSocket ping/pong",
|
||||
job_id=job_id,
|
||||
ping_count=ping_count,
|
||||
connection_healthy=True)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected", job_id=job_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in WebSocket message loop",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
break
|
||||
|
||||
finally:
|
||||
# Disconnect from manager
|
||||
await websocket_manager.disconnect(job_id, websocket)
|
||||
logger.info("WebSocket connection closed",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
Reference in New Issue
Block a user