Improve the event messaging for training service 2

This commit is contained in:
Urtzi Alfaro
2025-07-31 15:34:35 +02:00
parent 923b2d48d2
commit e581a144be
5 changed files with 288 additions and 10 deletions

View File

@@ -0,0 +1,257 @@
# services/training/app/api/websocket.py
"""
WebSocket endpoints for real-time training progress updates
"""
import json
import asyncio
import logging
from typing import Dict, Any
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter
from app.services.messaging import training_publisher
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
)
logger = logging.getLogger(__name__)
# Create WebSocket router
websocket_router = APIRouter()
class ConnectionManager:
"""Manage WebSocket connections for training progress"""
def __init__(self):
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
# Structure: {job_id: {connection_id: websocket}}
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
"""Accept WebSocket connection and register it"""
await websocket.accept()
if job_id not in self.active_connections:
self.active_connections[job_id] = {}
self.active_connections[job_id][connection_id] = websocket
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
def disconnect(self, job_id: str, connection_id: str):
"""Remove WebSocket connection"""
if job_id in self.active_connections:
self.active_connections[job_id].pop(connection_id, None)
if not self.active_connections[job_id]:
del self.active_connections[job_id]
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job"""
if job_id not in self.active_connections:
return
# Send to all connections for this job
disconnected_connections = []
for connection_id, websocket in self.active_connections[job_id].items():
try:
await websocket.send_json(message)
except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id)
# Clean up disconnected connections
for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id)
# Global connection manager
connection_manager = ConnectionManager()
@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str,
job_id: str
):
"""
WebSocket endpoint for real-time training progress updates
Message format sent to client:
{
"type": "progress" | "completed" | "failed" | "step_completed",
"job_id": "job_123",
"timestamp": "2025-07-30T19:08:53Z",
"data": {
"progress": 45,
"current_step": "Training model for bread",
"current_product": "bread",
"products_completed": 2,
"products_total": 5,
"estimated_time_remaining_minutes": 8
}
}
"""
connection_id = f"{tenant_id}_{id(websocket)}"
# Accept connection
await connection_manager.connect(websocket, job_id, connection_id)
# Set up RabbitMQ consumer for this job
consumer_task = None
try:
# Start RabbitMQ consumer
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
# Send initial status if available
try:
# You can fetch current job status from database here
initial_status = await get_current_job_status(job_id, tenant_id)
if initial_status:
await websocket.send_json({
"type": "initial_status",
"job_id": job_id,
"data": initial_status
})
except Exception as e:
logger.warning(f"Failed to send initial status: {e}")
# Keep connection alive and handle client messages
while True:
try:
# Wait for client ping or other messages
message = await websocket.receive_text()
if message == "ping":
await websocket.send_text("pong")
elif message == "get_status":
# Send current status on demand
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
break
finally:
# Clean up
connection_manager.disconnect(job_id, connection_id)
if consumer_task and not consumer_task.done():
consumer_task.cancel()
try:
await consumer_task
except asyncio.CancelledError:
pass
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
try:
# Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}"
async def handle_training_message(message):
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
try:
# Parse the message
body = message.body.decode()
data = json.loads(body)
# Extract event data
event_type = data.get("event_type", "unknown")
event_data = data.get("data", {})
# Only process messages for this specific job
if event_data.get("job_id") != job_id:
await message.ack()
return
# Transform RabbitMQ message to WebSocket message format
websocket_message = {
"type": map_event_type_to_websocket_type(event_type),
"job_id": job_id,
"timestamp": data.get("timestamp"),
"data": event_data
}
# Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message)
# Acknowledge the message
await message.ack()
logger.debug(f"Forwarded training event to WebSocket: {event_type}")
except Exception as e:
logger.error(f"Error handling training message for WebSocket: {e}")
await message.nack(requeue=False)
# Subscribe to training events
await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.*", # Listen to all training events
callback=handle_training_message
)
except Exception as e:
logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.completed": "completed",
"training.failed": "failed",
"training.cancelled": "cancelled",
"training.step.completed": "step_completed",
"training.product.started": "product_started",
"training.product.completed": "product_completed",
"training.product.failed": "product_failed",
"training.model.trained": "model_trained",
"training.data.validation.started": "validation_started",
"training.data.validation.completed": "validation_completed"
}
return mapping.get(rabbitmq_event_type, "unknown")
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
"""Get current job status from database or cache"""
try:
# This should query your database for current job status
# For now, return a placeholder - implement based on your database schema
from app.core.database import get_db_session
from app.models.training import ModelTrainingLog # Assuming you have this model
async with get_db_session() as db:
# Query your training job status
# This is a placeholder - adjust based on your actual database models
pass
# Placeholder return - replace with actual database query
return {
"job_id": job_id,
"status": "running", # or "completed", "failed", etc.
"progress": 0,
"current_step": "Starting...",
"started_at": "2025-07-30T19:00:00Z"
}
except Exception as e:
logger.error(f"Failed to get current job status: {e}")
return None