Files
bakery-ia/services/training/app/api/websocket_operations.py
Urtzi Alfaro 02f0c91a15 Fix UI issues
2025-12-29 19:33:35 +01:00

164 lines
6.3 KiB
Python

"""
WebSocket Operations for Training Service
Simple WebSocket endpoint that connects clients and receives broadcasts from RabbitMQ
"""
from typing import Any, Dict, Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Path, Query
import structlog

from app.websocket.manager import websocket_manager
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
from app.services.training_service import EnhancedTrainingService
from shared.database.base import create_database_manager
# Module-level structlog logger used by the WebSocket endpoint in this module.
logger = structlog.get_logger()
# Router holding this module's WebSocket route(s), grouped under the "websocket" tag.
router = APIRouter(tags=["websocket"])
def get_enhanced_training_service():
    """Return a new EnhancedTrainingService wired to a freshly created database manager.

    The database manager is built from ``settings.DATABASE_URL`` with the
    service name "training-service" each time this function is called.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
async def _authenticate_websocket(
    websocket: WebSocket,
    token: str,
    tenant_id: str,
    job_id: str,
) -> Optional[str]:
    """Verify the JWT supplied for a WebSocket connection.

    Args:
        websocket: The incoming (not yet registered) WebSocket.
        token: Raw JWT taken from the ``token`` query parameter.
        tenant_id: Tenant ID from the URL path (used for logging only).
        job_id: Job ID from the URL path (used for logging only).

    Returns:
        The ``user_id`` claim on success. On any failure the socket is closed
        with code 1008 and ``None`` is returned.
    """
    jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

    # Only verification and claim extraction are guarded. Keeping
    # websocket.close() outside the try means a failing close can never
    # trigger a second close on the same socket.
    try:
        payload = jwt_handler.verify_token(token)
        user_id = payload.get('user_id') if payload else None
    except Exception as e:
        await websocket.close(code=1008, reason="Authentication failed")
        logger.warning("WebSocket authentication failed",
                       job_id=job_id,
                       tenant_id=tenant_id,
                       error=str(e))
        return None

    if not payload:
        await websocket.close(code=1008, reason="Invalid token")
        logger.warning("WebSocket connection rejected - invalid token",
                       job_id=job_id,
                       tenant_id=tenant_id)
        return None

    if not user_id:
        await websocket.close(code=1008, reason="Invalid token payload")
        logger.warning("WebSocket connection rejected - no user_id in token",
                       job_id=job_id,
                       tenant_id=tenant_id)
        return None

    logger.info("WebSocket authentication successful",
                user_id=user_id,
                tenant_id=tenant_id,
                job_id=job_id)
    return user_id


def _build_status_message(
    job_id: str,
    status_info: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Translate a training-status dict into the JSON message sent to clients.

    Returns ``None`` when there is nothing to send: ``status_info`` is empty
    or carries an ``"error"`` entry.

    The message ``type`` is ``"completed"`` or ``"failed"`` when the job status
    matches those values, and ``"progress"`` otherwise.
    """
    if not status_info or status_info.get("error"):
        return None

    status = status_info.get("status")
    ws_type = status if status in ("completed", "failed") else "progress"

    return {
        "type": ws_type,
        "job_id": job_id,
        "data": {
            "progress": status_info.get("progress", 0),
            "current_step": status_info.get("current_step"),
            "status": status,
            "products_total": status_info.get("products_total", 0),
            "products_completed": status_info.get("products_completed", 0),
            "products_failed": status_info.get("products_failed", 0),
            "estimated_time_remaining_seconds": status_info.get("estimated_time_remaining_seconds"),
            "message": status_info.get("message"),
        },
    }


@router.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    token: str = Query(..., description="Authentication token")
):
    """
    WebSocket endpoint for real-time training progress updates.

    This endpoint:
    1. Validates the authentication token from the ``token`` query parameter,
       closing the socket with code 1008 on failure.
    2. Registers the socket with the WebSocket manager. This function never
       calls ``websocket.accept()`` itself; presumably the manager does.
       Confirm in app.websocket.manager.
    3. Sends a "connected" message, then the job's current status. This
       covers jobs that finished before the client connected.
    4. Keeps the connection alive: answers "ping" with "pong" and
       "get_status" with a fresh status snapshot.
    5. Progress broadcasts from RabbitMQ reach the client via the WebSocket
       manager.
    """
    user_id = await _authenticate_websocket(websocket, token, tenant_id, job_id)
    if user_id is None:
        return

    # NOTE(review): neither the caller's access to tenant_id nor job_id's
    # ownership by tenant_id is checked here. Any valid token can subscribe to
    # any job. Confirm this is enforced upstream (e.g. at the gateway), or add
    # a check once the token's tenant claim schema is known.

    # Connect to WebSocket manager
    await websocket_manager.connect(job_id, websocket)

    # The training service is created lazily on first use and then reused for
    # the life of this connection. get_enhanced_training_service() builds a new
    # database manager on every call, so creating one per "get_status" request
    # would be wasteful.
    training_service = None

    async def send_current_status():
        """Fetch the current job status and push it to the client.

        Failures are logged and swallowed so the connection stays open.
        """
        nonlocal training_service
        try:
            if training_service is None:
                training_service = get_enhanced_training_service()
            status_info = await training_service.get_training_status(job_id)

            message = _build_status_message(job_id, status_info)
            if message is None:
                logger.debug("No current job status available to send",
                             job_id=job_id)
                return

            await websocket.send_json(message)
            logger.info("Sent current job status to client",
                        job_id=job_id,
                        status=status_info.get("status"),
                        progress=status_info.get("progress"))
        except Exception as e:
            logger.error("Failed to send current job status",
                         job_id=job_id,
                         error=str(e))

    try:
        # Send connection confirmation
        await websocket.send_json({
            "type": "connected",
            "job_id": job_id,
            "message": "Connected to training progress stream"
        })

        # Immediately send current job status after connection.
        # This handles the race where training completes before the WebSocket connects.
        await send_current_status()

        # Keep connection alive and handle client messages
        ping_count = 0
        while True:
            try:
                # Receive messages from client (ping, get_status, etc.)
                data = await websocket.receive_text()

                if data == "ping":
                    await websocket.send_text("pong")
                    ping_count += 1
                    logger.debug("WebSocket ping/pong",
                                 job_id=job_id,
                                 ping_count=ping_count,
                                 connection_healthy=True)
                elif data == "get_status":
                    await send_current_status()
                    logger.info("Status requested by client", job_id=job_id)
                # Any other client message is ignored.

            except WebSocketDisconnect:
                logger.info("Client disconnected", job_id=job_id)
                break
            except Exception as e:
                logger.error("Error in WebSocket message loop",
                             job_id=job_id,
                             error=str(e))
                break
    finally:
        # Always unregister from the manager, however the loop ended.
        await websocket_manager.disconnect(job_id, websocket)
        logger.info("WebSocket connection closed",
                    job_id=job_id,
                    tenant_id=tenant_id)