Improve the event messaging for training service 2

This commit is contained in:
Urtzi Alfaro
2025-07-31 15:34:35 +02:00
parent 923b2d48d2
commit e581a144be
5 changed files with 288 additions and 10 deletions

View File

@@ -0,0 +1,257 @@
# services/training/app/api/websocket.py
"""
WebSocket endpoints for real-time training progress updates
"""
import json
import asyncio
import logging
from typing import Dict, Any
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter
from app.services.messaging import training_publisher
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that carries the WebSocket endpoint(s) declared in this module.
websocket_router = APIRouter()
class ConnectionManager:
    """Track the open WebSocket connections that follow each training job.

    Connections are kept as ``{job_id: {connection_id: websocket}}`` so that a
    message about one job can be fanned out to every client watching it.
    """

    def __init__(self):
        # Structure: {job_id: {connection_id: websocket}}
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept the WebSocket handshake and register the connection.

        Args:
            websocket: The connection to accept and store.
            job_id: Training job the client wants updates for.
            connection_id: Key identifying this connection within the job.
        """
        await websocket.accept()
        self.active_connections.setdefault(job_id, {})[connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Forget a connection; unknown jobs or connections are ignored.

        The per-job entry is dropped once its last connection is removed.
        """
        if job_id in self.active_connections:
            self.active_connections[job_id].pop(connection_id, None)
            if not self.active_connections[job_id]:
                del self.active_connections[job_id]
        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send ``message`` as JSON to every connection registered for ``job_id``.

        A connection whose send raises is logged and unregistered; the
        remaining connections still receive the message.
        """
        connections = self.active_connections.get(job_id)
        if not connections:
            return

        failed_connections = []
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, where connect()/disconnect() may add or remove entries for this
        # job. Iterating the live dict view would then raise RuntimeError
        # ("dictionary changed size during iteration").
        for connection_id, websocket in list(connections.items()):
            # Skip sockets that were unregistered while an earlier send awaited.
            registered = self.active_connections.get(job_id, {}).get(connection_id)
            if registered is not websocket:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                failed_connections.append(connection_id)

        # Clean up connections that could not be written to.
        for connection_id in failed_connections:
            self.disconnect(job_id, connection_id)
# Single module-level ConnectionManager instance; the WebSocket endpoint
# registers/unregisters clients on it and the RabbitMQ message handler in this
# module uses it to fan events out to them.
connection_manager = ConnectionManager()
@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """
    Stream live progress for one training job over a WebSocket.

    The connection is registered with the connection manager, a background
    task is started to consume the job's RabbitMQ events, an initial status
    snapshot is pushed when one is available, and the handler then answers
    client text frames ("ping" -> "pong", "get_status" -> current status)
    until the socket closes or errors. On exit the connection is
    unregistered and the consumer task is cancelled if still running.

    Message format sent to client:
    {
        "type": "progress" | "completed" | "failed" | "step_completed",
        "job_id": "job_123",
        "timestamp": "2025-07-30T19:08:53Z",
        "data": {
            "progress": 45,
            "current_step": "Training model for bread",
            "current_product": "bread",
            "products_completed": 2,
            "products_total": 5,
            "estimated_time_remaining_minutes": 8
        }
    }
    """
    connection_id = f"{tenant_id}_{id(websocket)}"

    async def push_status(message_type: str) -> None:
        """Send the job's current status under ``message_type`` when one exists."""
        status = await get_current_job_status(job_id, tenant_id)
        if not status:
            return
        await websocket.send_json({
            "type": message_type,
            "job_id": job_id,
            "data": status
        })

    # Accept the handshake and register this client for the job.
    await connection_manager.connect(websocket, job_id, connection_id)

    consumer_task = None
    try:
        # Background task that forwards RabbitMQ events for this job.
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        # Best effort: a failure here must not tear the connection down.
        try:
            await push_status("initial_status")
        except Exception as exc:
            logger.warning(f"Failed to send initial status: {exc}")

        # Serve client frames until the socket closes or an error occurs.
        listening = True
        while listening:
            try:
                incoming = await websocket.receive_text()
                if incoming == "ping":
                    await websocket.send_text("pong")
                elif incoming == "get_status":
                    await push_status("current_status")
            except WebSocketDisconnect:
                logger.info(f"WebSocket disconnected for job {job_id}")
                listening = False
            except Exception as exc:
                logger.error(f"WebSocket error for job {job_id}: {exc}")
                listening = False
    finally:
        connection_manager.disconnect(job_id, connection_id)
        still_running = consumer_task is not None and not consumer_task.done()
        if still_running:
            consumer_task.cancel()
            try:
                # Wait for the cancellation to be processed by the task.
                await consumer_task
            except asyncio.CancelledError:
                pass
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Consume training events from RabbitMQ and forward this job's events to WebSocket clients.

    Subscribes through ``training_publisher.consume_events`` on the
    ``training.events`` exchange. Every delivered message is JSON-decoded;
    messages whose ``data.job_id`` differs from ``job_id`` are acknowledged
    and dropped, the rest are reshaped and sent to all WebSocket connections
    registered for the job.

    Args:
        job_id: Training job whose events should be forwarded.
        tenant_id: Tenant identifier; used only to build the queue name.

    Any exception raised while subscribing is logged and swallowed, so the
    calling task ends quietly instead of propagating the failure.
    """
    try:
        # The queue name depends only on job and tenant, so every WebSocket
        # connection watching the same job for the same tenant uses this name.
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job
                if event_data.get("job_id") != job_id:
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Acknowledge the message
                await message.ack()
                logger.debug(f"Forwarded training event to WebSocket: {event_type}")
            except Exception as e:
                logger.error(f"Error handling training message for WebSocket: {e}")
                # Reject without requeue so a message that cannot be handled
                # is not redelivered over and over.
                await message.nack(requeue=False)

        # Subscribe to ALL training events. In AMQP topic bindings "*" matches
        # exactly one word while "#" matches zero or more, so "training.*"
        # would miss multi-word keys such as "training.step.completed" or
        # "training.product.started".
        # NOTE(review): assumes "training.events" is a topic exchange and that
        # routing keys mirror the event types -- confirm against the publisher.
        await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.#",
            callback=handle_training_message
        )
    except Exception as e:
        logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Translate a RabbitMQ training event type into a WebSocket message type.

    Args:
        rabbitmq_event_type: Event type as published, e.g. "training.progress".

    Returns:
        The matching WebSocket message type, or "unknown" when the event type
        is not in the table.
    """
    websocket_types = {
        # Job lifecycle
        "training.started": "started",
        "training.progress": "progress",
        "training.completed": "completed",
        "training.failed": "failed",
        "training.cancelled": "cancelled",
        "training.step.completed": "step_completed",
        # Per-product events
        "training.product.started": "product_started",
        "training.product.completed": "product_completed",
        "training.product.failed": "product_failed",
        "training.model.trained": "model_trained",
        # Data validation
        "training.data.validation.started": "validation_started",
        "training.data.validation.completed": "validation_completed"
    }
    try:
        return websocket_types[rabbitmq_event_type]
    except KeyError:
        return "unknown"
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
    """Return a status snapshot for a training job.

    NOTE: this is still a placeholder. It opens a database session but runs
    no query, and the returned values are fixed apart from ``job_id``.
    ``tenant_id`` is currently unused.

    Returns:
        A dict with ``job_id``, ``status``, ``progress``, ``current_step`` and
        ``started_at``; ``None`` if importing the database helpers or opening
        the session raises (the error is logged).
    """
    try:
        # Imported lazily inside the function, as in the original placeholder.
        from app.core.database import get_db_session
        from app.models.training import ModelTrainingLog  # Assuming you have this model

        async with get_db_session() as db:
            # TODO: query the real job status here -- adjust to the actual
            # database models.
            pass
    except Exception as exc:
        logger.error(f"Failed to get current job status: {exc}")
        return None

    # Placeholder payload -- replace with the result of the database query.
    return {
        "job_id": job_id,
        "status": "running",  # or "completed", "failed", etc.
        "progress": 0,
        "current_step": "Starting...",
        "started_at": "2025-07-30T19:00:00Z"
    }

View File

@@ -18,6 +18,7 @@ import uvicorn
from app.core.config import settings
from app.core.database import initialize_training_database, cleanup_training_database
from app.api import training, models
from app.api.websocket import websocket_router
from app.services.messaging import setup_messaging, cleanup_messaging
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
@@ -176,6 +177,7 @@ async def global_exception_handler(request: Request, exc: Exception):
# Include API routers
app.include_router(training.router, prefix="/api/v1", tags=["training"])
app.include_router(models.router, prefix="/api/v1", tags=["models"])
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
# Health check endpoints

View File

@@ -18,6 +18,13 @@ from app.core.config import settings
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed
)
logger = logging.getLogger(__name__)
class BakeryMLTrainer:

View File

@@ -63,6 +63,15 @@ def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
"""
return serialize_for_json(data)
async def setup_websocket_message_routing():
    """Placeholder hook for WebSocket message routing; currently does nothing.

    Intended to be called from the WebSocket endpoint to set up the consumer
    for a specific job. The previous body wrapped a bare ``pass`` in
    ``try``/``except``; that handler could never run, so the function is now
    an explicit no-op until real routing logic is added.

    Returns:
        None.
    """
    # TODO: implement per-job routing, and add error handling together with
    # the code that can actually fail.
    return None
# =========================================
# ENHANCED TRAINING JOB STATUS EVENTS
# =========================================

View File

@@ -16,6 +16,15 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.database import get_db_session
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed
)
logger = logging.getLogger(__name__)
class TrainingService:
@@ -61,18 +70,12 @@ class TrainingService:
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
from app.services.messaging import TrainingStatusPublisher
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
try:
await status_publisher.job_started({
"bakery_location": bakery_location,
"has_custom_dates": bool(requested_start or requested_end)
}, 0) # Will be updated when we know product count
# Step 1: Prepare training dataset with date alignment and orchestration
logger.info("Step 1: Preparing and aligning training data")
await publish_job_progress(job_id, tenant_id, 0, "Extrayendo datos de ventas")
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=bakery_location,
@@ -83,6 +86,7 @@ class TrainingService:
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
await publish_job_progress(job_id, tenant_id, 35, "Starting ML training pipeline")
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
@@ -110,12 +114,11 @@ class TrainingService:
}
logger.info(f"Training job {job_id} completed successfully")
await status_publisher.job_completed(final_result)
await publish_job_completed(job_id, tenant_id, final_result);
return TrainingService.create_detailed_training_response(final_result)
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
await status_publisher.job_failed(str(e))
# Return error response in same detailed format
final_result = {
"job_id": job_id,
@@ -139,7 +142,7 @@ class TrainingService:
"completed_at": datetime.now().isoformat(),
"error_message": str(e)
}
await publish_job_failed(job_id, tenant_id, str(e), final_result)
return TrainingService.create_detailed_training_response(final_result)
async def start_single_product_training(