2025-10-09 14:11:02 +02:00
|
|
|
"""
|
|
|
|
|
WebSocket Operations for Training Service
|
|
|
|
|
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
|
|
|
|
|
import structlog
|
|
|
|
|
|
|
|
|
|
from app.websocket.manager import websocket_manager
|
|
|
|
|
from shared.auth.jwt_handler import JWTHandler
|
|
|
|
|
from app.core.config import settings
|
2025-12-29 19:33:35 +01:00
|
|
|
from app.services.training_service import EnhancedTrainingService
|
|
|
|
|
from shared.database.base import create_database_manager
|
2025-10-09 14:11:02 +02:00
|
|
|
|
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
|
|
|
|
router = APIRouter(tags=["websocket"])
|
|
|
|
|
|
|
|
|
|
|
2025-12-29 19:33:35 +01:00
|
|
|
def get_enhanced_training_service():
    """Build an EnhancedTrainingService wired to a new database manager.

    Every call creates a new database manager from ``settings.DATABASE_URL``
    (registered under the name "training-service") and wraps it in a new
    service instance; nothing is cached at this level.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
|
|
|
|
|
|
|
|
|
|
|
2025-10-09 14:11:02 +02:00
|
|
|
@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    token: str = Query(..., description="Authentication token"),
):
    """
    WebSocket endpoint for real-time training progress updates.

    Flow:
    1. Verify the JWT passed in the ``token`` query parameter; on any failure
       the socket is closed with code 1008 and the handler returns.
    2. Register the socket with ``websocket_manager`` under ``job_id`` so that
       broadcasts for that job reach this client.
    3. Send a "connected" message followed by the job's current status (this
       covers the case where training finished before the client connected).
    4. Loop on client text messages: "ping" is answered with "pong",
       "get_status" re-sends the current status, anything else is ignored.
    5. Always unregister the socket from the manager on exit.

    Args:
        websocket: The incoming WebSocket connection.
        tenant_id: Tenant identifier from the URL path (used for logging here).
        job_id: Training job identifier from the URL path.
        token: JWT used to authenticate the connection.
    """
    # NOTE(review): tenant_id is only used in log calls below; it is never
    # compared against a claim in the token. Confirm tenant-level access is
    # enforced elsewhere (e.g. gateway) or add a check once the payload
    # schema is known.

    # Validate token
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            await websocket.close(code=1008, reason="Invalid token")
            logger.warning("WebSocket connection rejected - invalid token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            return

        user_id = payload.get('user_id')
        if not user_id:
            await websocket.close(code=1008, reason="Invalid token payload")
            logger.warning("WebSocket connection rejected - no user_id in token",
                           job_id=job_id,
                           tenant_id=tenant_id)
            return

        logger.info("WebSocket authentication successful",
                    user_id=user_id,
                    tenant_id=tenant_id,
                    job_id=job_id)

    except Exception as e:
        await websocket.close(code=1008, reason="Authentication failed")
        logger.warning("WebSocket authentication failed",
                       job_id=job_id,
                       tenant_id=tenant_id,
                       error=str(e))
        return

    # Connect to WebSocket manager
    await websocket_manager.connect(job_id, websocket)

    # get_enhanced_training_service() builds a new database manager on every
    # call, so the service is created lazily once and then reused for the
    # lifetime of this connection instead of once per status request.
    training_service = None

    async def send_current_status():
        """Fetch and send the current job status to the client.

        Errors are logged and swallowed so that a failed status lookup never
        tears down the connection.
        """
        nonlocal training_service
        try:
            if training_service is None:
                training_service = get_enhanced_training_service()
            status_info = await training_service.get_training_status(job_id)

            # Nothing is sent when the lookup returned nothing or an error.
            if status_info and not status_info.get("error"):
                # Map status to WebSocket message type: terminal states are
                # reported under their own type, everything else is progress.
                job_status = status_info.get("status")
                ws_type = job_status if job_status in ("completed", "failed") else "progress"

                await websocket.send_json({
                    "type": ws_type,
                    "job_id": job_id,
                    "data": {
                        "progress": status_info.get("progress", 0),
                        "current_step": status_info.get("current_step"),
                        "status": job_status,
                        "products_total": status_info.get("products_total", 0),
                        "products_completed": status_info.get("products_completed", 0),
                        "products_failed": status_info.get("products_failed", 0),
                        "estimated_time_remaining_seconds": status_info.get("estimated_time_remaining_seconds"),
                        "message": status_info.get("message")
                    }
                })
                logger.info("Sent current job status to client",
                            job_id=job_id,
                            status=job_status,
                            progress=status_info.get("progress"))
        except Exception as e:
            logger.error("Failed to send current job status",
                         job_id=job_id,
                         error=str(e))

    try:
        # Send connection confirmation
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Connected to training progress stream"
        })

        # Immediately send current job status after connection.
        # This handles the race condition where training completes before
        # the WebSocket connects.
        await send_current_status()

        # Keep connection alive and handle client messages
        ping_count = 0
        while True:
            try:
                # Receive messages from client (ping, get_status, etc.)
                data = await websocket.receive_text()

                # Handle ping/pong
                if data == "ping":
                    await websocket.send_text("pong")
                    ping_count += 1
                    logger.debug("WebSocket ping/pong",
                                 job_id=job_id,
                                 ping_count=ping_count,
                                 connection_healthy=True)
                # Handle get_status request
                elif data == "get_status":
                    await send_current_status()
                    logger.info("Status requested by client", job_id=job_id)

            except WebSocketDisconnect:
                logger.info("Client disconnected", job_id=job_id)
                break
            except Exception as e:
                logger.error("Error in WebSocket message loop",
                             job_id=job_id,
                             error=str(e))
                break

    finally:
        # Disconnect from manager (runs on normal exit and on errors alike)
        await websocket_manager.disconnect(job_id, websocket)
        logger.info("WebSocket connection closed",
                    job_id=job_id,
                    tenant_id=tenant_id)
|