Improve the event messaging for training service 2
This commit is contained in:
257
services/training/app/api/websocket.py
Normal file
257
services/training/app/api/websocket.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
# services/training/app/api/websocket.py
|
||||||
|
"""
|
||||||
|
WebSocket endpoints for real-time training progress updates
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
|
||||||
|
from app.services.messaging import training_publisher
|
||||||
|
from shared.auth.decorators import (
|
||||||
|
get_current_user_dep,
|
||||||
|
get_current_tenant_id_dep
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create WebSocket router
|
||||||
|
# The application mounts this router under the "/api/v1/ws" prefix.
websocket_router = APIRouter()
|
||||||
|
|
||||||
|
class ConnectionManager:
    """Track the open WebSocket connections that follow each training job.

    Connections are stored as ``{job_id: {connection_id: websocket}}`` so
    that one event can be fanned out to every client watching the same job.
    """

    def __init__(self):
        # Structure: {job_id: {connection_id: websocket}}
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept the WebSocket handshake and register the connection.

        Args:
            websocket: The connection to accept and track.
            job_id: Training job the client wants updates for.
            connection_id: Key identifying this connection within the job.
        """
        await websocket.accept()

        self.active_connections.setdefault(job_id, {})[connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Forget a connection; unknown jobs or connections are ignored.

        The job entry itself is dropped once its last connection is removed.
        """
        job_connections = self.active_connections.get(job_id)
        if job_connections is not None:
            job_connections.pop(connection_id, None)
            if not job_connections:
                del self.active_connections[job_id]

        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send a JSON message to every connection registered for a job.

        Connections whose send fails are logged and unregistered; the
        failure is not propagated to the caller.

        Args:
            job_id: Training job whose listeners should receive the message.
            message: JSON-serialisable payload passed to ``send_json``.
        """
        job_connections = self.active_connections.get(job_id)
        if not job_connections:
            return

        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, where another coroutine may connect or disconnect a client
        # for this job. Iterating the live dict while it changes size raises
        # "RuntimeError: dictionary changed size during iteration".
        failed_connection_ids = []
        for connection_id, websocket in list(job_connections.items()):
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                failed_connection_ids.append(connection_id)

        # Clean up connections that could not be written to
        for connection_id in failed_connection_ids:
            self.disconnect(job_id, connection_id)
|
||||||
|
|
||||||
|
# Global connection manager
|
||||||
|
# Shared by the WebSocket endpoint and the RabbitMQ message handler in this module.
connection_manager = ConnectionManager()
|
||||||
|
|
||||||
|
async def _push_status_snapshot(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str,
    message_type: str
) -> None:
    """Look up the job's current status and send it to the client if there is one."""
    status = await get_current_job_status(job_id, tenant_id)
    if status:
        await websocket.send_json({
            "type": message_type,
            "job_id": job_id,
            "data": status
        })


@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """
    Stream real-time training progress for one job over a WebSocket.

    The connection is registered with ``connection_manager`` and a background
    task running ``setup_rabbitmq_consumer_for_job`` is started for the job.
    An "initial_status" message is sent when a status is available, then the
    handler waits for client text messages until the socket closes or errors.
    On exit the connection is unregistered and the background task cancelled.

    Text messages accepted from the client:
        "ping"        -> answered with the text "pong"
        "get_status"  -> answered with a "current_status" message when a
                         status is available
        Any other text is ignored.

    Message format sent to client:
    {
        "type": "progress" | "completed" | "failed" | "step_completed",
        "job_id": "job_123",
        "timestamp": "2025-07-30T19:08:53Z",
        "data": {
            "progress": 45,
            "current_step": "Training model for bread",
            "current_product": "bread",
            "products_completed": 2,
            "products_total": 5,
            "estimated_time_remaining_minutes": 8
        }
    }

    NOTE(review): nothing in this handler checks that the caller may access
    ``tenant_id`` -- confirm authentication is enforced elsewhere.
    """
    connection_id = f"{tenant_id}_{id(websocket)}"

    # Accept the handshake and start tracking this client
    await connection_manager.connect(websocket, job_id, connection_id)

    consumer_task = None

    try:
        # Background task that runs the RabbitMQ consumer for this job
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        # Best effort: a failure here must not end the connection
        try:
            await _push_status_snapshot(websocket, tenant_id, job_id, "initial_status")
        except Exception as exc:
            logger.warning(f"Failed to send initial status: {exc}")

        # Serve client commands until the socket closes or fails
        try:
            while True:
                incoming = await websocket.receive_text()

                if incoming == "ping":
                    await websocket.send_text("pong")
                elif incoming == "get_status":
                    await _push_status_snapshot(websocket, tenant_id, job_id, "current_status")
        except WebSocketDisconnect:
            logger.info(f"WebSocket disconnected for job {job_id}")
        except Exception as exc:
            logger.error(f"WebSocket error for job {job_id}: {exc}")

    finally:
        # Always unregister the client and stop its background task
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task is not None and not consumer_task.done():
            consumer_task.cancel()
            try:
                await consumer_task
            except asyncio.CancelledError:
                pass
|
||||||
|
|
||||||
|
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Consume training events from RabbitMQ and forward this job's events to WebSocket clients.

    Subscribes through ``training_publisher.consume_events`` on the
    "training.events" exchange. Each message is decoded as JSON; messages
    whose ``data.job_id`` differs from ``job_id`` are acknowledged and
    dropped, matching ones are reshaped and passed to
    ``connection_manager.send_to_job``. Messages that fail processing are
    nacked without requeue.

    Errors raised while subscribing are logged and not propagated.

    Args:
        job_id: Training job whose events should be forwarded.
        tenant_id: Tenant owning the job; used only to build the queue name.
    """
    try:
        # The queue is named per job and tenant, so every WebSocket
        # connection watching the same job uses the same queue name.
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)

                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})

                # Only process messages for this specific job
                if event_data.get("job_id") != job_id:
                    await message.ack()
                    return

                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }

                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)

                # Acknowledge the message
                await message.ack()

                logger.debug(f"Forwarded training event to WebSocket: {event_type}")

            except Exception as e:
                logger.error(f"Error handling training message for WebSocket: {e}")
                # Drop the message instead of requeueing it
                await message.nack(requeue=False)

        # Subscribe to training events.
        # On a topic exchange "*" matches exactly one dot-separated word,
        # while "#" matches zero or more. Event types handled by this module
        # such as "training.step.completed" have several words after
        # "training.", so "#" is needed to receive them.
        # NOTE(review): assumes "training.events" is a topic exchange and
        # that routing keys mirror the event type names -- confirm against
        # the publisher.
        await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.#",  # Listen to all training events
            callback=handle_training_message
        )

    except Exception as e:
        logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}")
|
||||||
|
|
||||||
|
# RabbitMQ event type -> value of the "type" field in WebSocket messages.
_WEBSOCKET_TYPE_BY_EVENT = {
    "training.started": "started",
    "training.progress": "progress",
    "training.completed": "completed",
    "training.failed": "failed",
    "training.cancelled": "cancelled",
    "training.step.completed": "step_completed",
    "training.product.started": "product_started",
    "training.product.completed": "product_completed",
    "training.product.failed": "product_failed",
    "training.model.trained": "model_trained",
    "training.data.validation.started": "validation_started",
    "training.data.validation.completed": "validation_completed",
}


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Translate a RabbitMQ event type into a WebSocket message type.

    Returns:
        The mapped message type, or "unknown" when the event type is not
        in the lookup table.
    """
    try:
        return _WEBSOCKET_TYPE_BY_EVENT[rabbitmq_event_type]
    except KeyError:
        return "unknown"
|
||||||
|
|
||||||
|
async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
    """Return a status snapshot for a training job, or ``None`` on failure.

    NOTE: this is still a placeholder. It opens a database session but does
    not query anything, and always returns the same hard-coded "running"
    snapshot for the given ``job_id``. ``tenant_id`` is currently unused.

    Args:
        job_id: Identifier of the training job.
        tenant_id: Tenant owning the job.

    Returns:
        A dict with "job_id", "status", "progress", "current_step" and
        "started_at", or ``None`` if anything in the lookup raised (the
        error is logged).
    """
    try:
        # This should query your database for current job status
        # For now, return a placeholder - implement based on your database schema

        # Imported inside the function; a failing import is caught by the
        # handler below and reported as "no status".
        from app.core.database import get_db_session
        from app.models.training import ModelTrainingLog  # Assuming you have this model

        async with get_db_session() as db:
            # Query your training job status
            # This is a placeholder - adjust based on your actual database models
            pass

        # Placeholder return - replace with actual database query
        return {
            "job_id": job_id,
            "status": "running",  # or "completed", "failed", etc.
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }

    except Exception as e:
        logger.error(f"Failed to get current job status: {e}")
        return None
|
||||||
@@ -18,6 +18,7 @@ import uvicorn
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import initialize_training_database, cleanup_training_database
|
from app.core.database import initialize_training_database, cleanup_training_database
|
||||||
from app.api import training, models
|
from app.api import training, models
|
||||||
|
from app.api.websocket import websocket_router
|
||||||
from app.services.messaging import setup_messaging, cleanup_messaging
|
from app.services.messaging import setup_messaging, cleanup_messaging
|
||||||
from shared.monitoring.logging import setup_logging
|
from shared.monitoring.logging import setup_logging
|
||||||
from shared.monitoring.metrics import MetricsCollector
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
@@ -176,6 +177,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||||||
# Include API routers
|
# Include API routers
|
||||||
app.include_router(training.router, prefix="/api/v1", tags=["training"])
|
app.include_router(training.router, prefix="/api/v1", tags=["training"])
|
||||||
app.include_router(models.router, prefix="/api/v1", tags=["models"])
|
app.include_router(models.router, prefix="/api/v1", tags=["models"])
|
||||||
|
app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
|
||||||
|
|
||||||
|
|
||||||
# Health check endpoints
|
# Health check endpoints
|
||||||
|
|||||||
@@ -18,6 +18,13 @@ from app.core.config import settings
|
|||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.messaging import (
|
||||||
|
publish_job_progress,
|
||||||
|
publish_data_validation_started,
|
||||||
|
publish_data_validation_completed,
|
||||||
|
publish_job_step_completed
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BakeryMLTrainer:
|
class BakeryMLTrainer:
|
||||||
|
|||||||
@@ -63,6 +63,15 @@ def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
return serialize_for_json(data)
|
return serialize_for_json(data)
|
||||||
|
|
||||||
|
async def setup_websocket_message_routing():
    """Set up message routing for WebSocket connections.

    Placeholder: the coroutine currently does nothing and returns ``None``.
    The previous ``try``/``except`` wrapped only a ``pass`` statement, so its
    error handler could never run and has been removed.
    """
    # This will be called from the WebSocket endpoint
    # to set up the consumer for a specific job
    return None
|
||||||
|
|
||||||
# =========================================
|
# =========================================
|
||||||
# ENHANCED TRAINING JOB STATUS EVENTS
|
# ENHANCED TRAINING JOB STATUS EVENTS
|
||||||
# =========================================
|
# =========================================
|
||||||
|
|||||||
@@ -16,6 +16,15 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
|
|||||||
|
|
||||||
from app.core.database import get_db_session
|
from app.core.database import get_db_session
|
||||||
|
|
||||||
|
from app.services.messaging import (
|
||||||
|
publish_job_progress,
|
||||||
|
publish_data_validation_started,
|
||||||
|
publish_data_validation_completed,
|
||||||
|
publish_job_step_completed,
|
||||||
|
publish_job_completed,
|
||||||
|
publish_job_failed
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TrainingService:
|
class TrainingService:
|
||||||
@@ -61,18 +70,12 @@ class TrainingService:
|
|||||||
|
|
||||||
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
|
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
|
||||||
|
|
||||||
from app.services.messaging import TrainingStatusPublisher
|
|
||||||
status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
await status_publisher.job_started({
|
|
||||||
"bakery_location": bakery_location,
|
|
||||||
"has_custom_dates": bool(requested_start or requested_end)
|
|
||||||
}, 0) # Will be updated when we know product count
|
|
||||||
|
|
||||||
# Step 1: Prepare training dataset with date alignment and orchestration
|
# Step 1: Prepare training dataset with date alignment and orchestration
|
||||||
logger.info("Step 1: Preparing and aligning training data")
|
logger.info("Step 1: Preparing and aligning training data")
|
||||||
|
await publish_job_progress(job_id, tenant_id, 0, "Extrayendo datos de ventas")
|
||||||
training_dataset = await self.orchestrator.prepare_training_data(
|
training_dataset = await self.orchestrator.prepare_training_data(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
bakery_location=bakery_location,
|
bakery_location=bakery_location,
|
||||||
@@ -83,6 +86,7 @@ class TrainingService:
|
|||||||
|
|
||||||
# Step 2: Execute ML training pipeline
|
# Step 2: Execute ML training pipeline
|
||||||
logger.info("Step 2: Starting ML training pipeline")
|
logger.info("Step 2: Starting ML training pipeline")
|
||||||
|
await publish_job_progress(job_id, tenant_id, 35, "Starting ML training pipeline")
|
||||||
training_results = await self.trainer.train_tenant_models(
|
training_results = await self.trainer.train_tenant_models(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
training_dataset=training_dataset,
|
training_dataset=training_dataset,
|
||||||
@@ -110,12 +114,11 @@ class TrainingService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Training job {job_id} completed successfully")
|
logger.info(f"Training job {job_id} completed successfully")
|
||||||
await status_publisher.job_completed(final_result)
|
await publish_job_completed(job_id, tenant_id, final_result);
|
||||||
return TrainingService.create_detailed_training_response(final_result)
|
return TrainingService.create_detailed_training_response(final_result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training job {job_id} failed: {str(e)}")
|
logger.error(f"Training job {job_id} failed: {str(e)}")
|
||||||
await status_publisher.job_failed(str(e))
|
|
||||||
# Return error response in same detailed format
|
# Return error response in same detailed format
|
||||||
final_result = {
|
final_result = {
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
@@ -139,7 +142,7 @@ class TrainingService:
|
|||||||
"completed_at": datetime.now().isoformat(),
|
"completed_at": datetime.now().isoformat(),
|
||||||
"error_message": str(e)
|
"error_message": str(e)
|
||||||
}
|
}
|
||||||
|
await publish_job_failed(job_id, tenant_id, str(e), final_result)
|
||||||
return TrainingService.create_detailed_training_response(final_result)
|
return TrainingService.create_detailed_training_response(final_result)
|
||||||
|
|
||||||
async def start_single_product_training(
|
async def start_single_product_training(
|
||||||
|
|||||||
Reference in New Issue
Block a user