diff --git a/services/training/app/api/websocket.py b/services/training/app/api/websocket.py
new file mode 100644
index 00000000..dddd7317
--- /dev/null
+++ b/services/training/app/api/websocket.py
@@ -0,0 +1,257 @@
+# services/training/app/api/websocket.py
+"""
+WebSocket endpoints for real-time training progress updates
+"""
+
+import json
+import asyncio
+import logging
+from typing import Dict, Any, Optional
+from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
+from fastapi.routing import APIRouter
+
+from app.services.messaging import training_publisher
+from shared.auth.decorators import (
+    get_current_user_dep,
+    get_current_tenant_id_dep
+)
+
+logger = logging.getLogger(__name__)
+
+# Create WebSocket router
+websocket_router = APIRouter()
+
+class ConnectionManager:
+    """Manage WebSocket connections for training progress"""
+
+    def __init__(self):
+        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
+        # Structure: {job_id: {connection_id: websocket}}
+
+    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
+        """Accept WebSocket connection and register it"""
+        await websocket.accept()
+
+        if job_id not in self.active_connections:
+            self.active_connections[job_id] = {}
+
+        self.active_connections[job_id][connection_id] = websocket
+        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
+
+    def disconnect(self, job_id: str, connection_id: str):
+        """Remove WebSocket connection"""
+        if job_id in self.active_connections:
+            self.active_connections[job_id].pop(connection_id, None)
+            if not self.active_connections[job_id]:
+                del self.active_connections[job_id]
+
+        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
+
+    async def send_to_job(self, job_id: str, message: dict):
+        """Send message to all connections for a specific job"""
+        if job_id not in self.active_connections:
+            return
+
+        # Send to all connections for this job
+        disconnected_connections = []
+
+        for connection_id, websocket in self.active_connections[job_id].items():
+            try:
+                await websocket.send_json(message)
+            except Exception as e:
+                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
+                disconnected_connections.append(connection_id)
+
+        # Clean up disconnected connections
+        for connection_id in disconnected_connections:
+            self.disconnect(job_id, connection_id)
+
+# Global connection manager
+connection_manager = ConnectionManager()
+
+@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
+async def training_progress_websocket(
+    websocket: WebSocket,
+    tenant_id: str,
+    job_id: str
+):
+    """
+    WebSocket endpoint for real-time training progress updates
+
+    Message format sent to client:
+    {
+        "type": "progress" | "completed" | "failed" | "step_completed",
+        "job_id": "job_123",
+        "timestamp": "2025-07-30T19:08:53Z",
+        "data": {
+            "progress": 45,
+            "current_step": "Training model for bread",
+            "current_product": "bread",
+            "products_completed": 2,
+            "products_total": 5,
+            "estimated_time_remaining_minutes": 8
+        }
+    }
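+
+    Clients may also send "ping" (answered with "pong") or "get_status"
+    (answered with a "current_status" message).
+
+    Example client (illustrative sketch only; assumes the third-party
+    `websockets` package, a service reachable at ws://localhost:8000, and
+    the /api/v1/ws prefix used in main.py):
+
+        import asyncio, json, websockets
+
+        async def follow(tenant_id: str, job_id: str):
+            url = (
+                "ws://localhost:8000/api/v1/ws"
+                f"/tenants/{tenant_id}/training/jobs/{job_id}/live"
+            )
+            async with websockets.connect(url) as ws:
+                await ws.send("get_status")
+                async for raw in ws:
+                    event = json.loads(raw)
+                    print(event["type"], event.get("data", {}))
+
+        # e.g. asyncio.run(follow("tenant_1", "job_123"))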
+    """
+    connection_id = f"{tenant_id}_{id(websocket)}"
+
+    # Accept connection
+    await connection_manager.connect(websocket, job_id, connection_id)
+
+    # Set up RabbitMQ consumer for this job
+    consumer_task = None
+
+    try:
+        # Start RabbitMQ consumer
+        consumer_task = asyncio.create_task(
+            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
+        )
+
+        # Send initial status if available
+        try:
+            # You can fetch current job status from database here
+            initial_status = await get_current_job_status(job_id, tenant_id)
+            if initial_status:
+                await websocket.send_json({
+                    "type": "initial_status",
+                    "job_id": job_id,
+                    "data": initial_status
+                })
+        except Exception as e:
+            logger.warning(f"Failed to send initial status: {e}")
+
+        # Keep connection alive and handle client messages
+        while True:
+            try:
+                # Wait for client ping or other messages
+                message = await websocket.receive_text()
+
+                if message == "ping":
+                    await websocket.send_text("pong")
+                elif message == "get_status":
+                    # Send current status on demand
+                    current_status = await get_current_job_status(job_id, tenant_id)
+                    if current_status:
+                        await websocket.send_json({
+                            "type": "current_status",
+                            "job_id": job_id,
+                            "data": current_status
+                        })
+
+            except WebSocketDisconnect:
+                logger.info(f"WebSocket disconnected for job {job_id}")
+                break
+            except Exception as e:
+                logger.error(f"WebSocket error for job {job_id}: {e}")
+                break
+
+    finally:
+        # Clean up
+        connection_manager.disconnect(job_id, connection_id)
+
+        if consumer_task and not consumer_task.done():
+            consumer_task.cancel()
+            try:
+                await consumer_task
+            except asyncio.CancelledError:
+                pass
+
+async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
+    """Set up RabbitMQ consumer to listen for training events for a specific job"""
+
+    try:
+        # Create a unique queue for this WebSocket connection
+        queue_name = f"websocket_training_{job_id}_{tenant_id}"
+
+        async def handle_training_message(message):
+            """Handle incoming RabbitMQ messages and forward to WebSocket"""
+            try:
+                # Parse the message
+                body = message.body.decode()
+                data = json.loads(body)
+
+                # Extract event data
+                event_type = data.get("event_type", "unknown")
+                event_data = data.get("data", {})
+
+                # Only process messages for this specific job
+                if event_data.get("job_id") != job_id:
+                    await message.ack()
+                    return
+
+                # Transform RabbitMQ message to WebSocket message format
+                websocket_message = {
+                    "type": map_event_type_to_websocket_type(event_type),
+                    "job_id": job_id,
+                    "timestamp": data.get("timestamp"),
+                    "data": event_data
+                }
+
+                # Send to all WebSocket connections for this job
+                await connection_manager.send_to_job(job_id, websocket_message)
+
+                # Acknowledge the message
+                await message.ack()
+
+                logger.debug(f"Forwarded training event to WebSocket: {event_type}")
+
+            except Exception as e:
+                logger.error(f"Error handling training message for WebSocket: {e}")
+                await message.nack(requeue=False)
+
+        # Subscribe to training events
+        await training_publisher.consume_events(
+            exchange_name="training.events",
+            queue_name=queue_name,
+            routing_key="training.*",  # Listen to all training events
+            callback=handle_training_message
+        )
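+
+        # NOTE: consume_events(...) is assumed to declare/bind the queue and keep
+        # consuming until this task is cancelled on WebSocket disconnect. For
+        # reference, handle_training_message expects payloads shaped roughly like
+        # (field names inferred from this module, not a formal schema):
+        #
+        #     {
+        #         "event_type": "training.progress",
+        #         "timestamp": "2025-07-30T19:08:53Z",
+        #         "data": {"job_id": "job_123", "progress": 45, ...}
+        #     }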
+
+    except Exception as e:
+        logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}")
+
+def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
+    """Map RabbitMQ event types to WebSocket message types"""
+    mapping = {
+        "training.started": "started",
+        "training.progress": "progress",
+        "training.completed": "completed",
+        "training.failed": "failed",
+        "training.cancelled": "cancelled",
+        "training.step.completed": "step_completed",
+        "training.product.started": "product_started",
+        "training.product.completed": "product_completed",
+        "training.product.failed": "product_failed",
+        "training.model.trained": "model_trained",
+        "training.data.validation.started": "validation_started",
+        "training.data.validation.completed": "validation_completed"
+    }
+
+    return mapping.get(rabbitmq_event_type, "unknown")
+
+async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
+    """Get current job status from database or cache"""
+    try:
+        # This should query your database for current job status
+        # For now, return a placeholder - implement based on your database schema
+
+        from app.core.database import get_db_session
+        from app.models.training import ModelTrainingLog  # Assuming you have this model
+
+        async with get_db_session() as db:
+            # Query your training job status
+            # This is a placeholder - adjust based on your actual database models
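+            # A possible query, assuming ModelTrainingLog has job_id, tenant_id,
+            # status, progress, current_step and started_at columns (adjust to
+            # the real schema):
+            #
+            #     from sqlalchemy import select
+            #
+            #     result = await db.execute(
+            #         select(ModelTrainingLog).where(
+            #             ModelTrainingLog.job_id == job_id,
+            #             ModelTrainingLog.tenant_id == tenant_id
+            #         )
+            #     )
+            #     log = result.scalar_one_or_none()
+            #     if log:
+            #         return {
+            #             "job_id": log.job_id,
+            #             "status": log.status,
+            #             "progress": log.progress,
+            #             "current_step": log.current_step,
+            #             "started_at": log.started_at.isoformat()
+            #         }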
+            pass
+
+        # Placeholder return - replace with actual database query
+        return {
+            "job_id": job_id,
+            "status": "running",  # or "completed", "failed", etc.
+            "progress": 0,
+            "current_step": "Starting...",
+            "started_at": "2025-07-30T19:00:00Z"
+        }
+
+    except Exception as e:
+        logger.error(f"Failed to get current job status: {e}")
+        return None
\ No newline at end of file
diff --git a/services/training/app/main.py b/services/training/app/main.py
index 256a1766..f7f1aac9 100644
--- a/services/training/app/main.py
+++ b/services/training/app/main.py
@@ -18,6 +18,7 @@ import uvicorn
 from app.core.config import settings
 from app.core.database import initialize_training_database, cleanup_training_database
 from app.api import training, models
+from app.api.websocket import websocket_router
 from app.services.messaging import setup_messaging, cleanup_messaging
 from shared.monitoring.logging import setup_logging
 from shared.monitoring.metrics import MetricsCollector
@@ -176,6 +177,7 @@ async def global_exception_handler(request: Request, exc: Exception):
 # Include API routers
 app.include_router(training.router, prefix="/api/v1", tags=["training"])
 app.include_router(models.router, prefix="/api/v1", tags=["models"])
+app.include_router(websocket_router, prefix="/api/v1/ws", tags=["websocket"])
 
 
 # Health check endpoints
diff --git a/services/training/app/ml/trainer.py b/services/training/app/ml/trainer.py
index d2d30bcb..530b6126 100644
--- a/services/training/app/ml/trainer.py
+++ b/services/training/app/ml/trainer.py
@@ -18,6 +18,13 @@ from app.core.config import settings
 
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from app.services.messaging import (
+    publish_job_progress,
+    publish_data_validation_started,
+    publish_data_validation_completed,
+    publish_job_step_completed
+)
+
 logger = logging.getLogger(__name__)
 
 class BakeryMLTrainer:
diff --git a/services/training/app/services/messaging.py b/services/training/app/services/messaging.py
index f6cebd34..749370f1 100644
--- a/services/training/app/services/messaging.py
+++ b/services/training/app/services/messaging.py
@@ -63,6 +63,15 @@ def safe_json_serialize(data: Dict[str, Any]) -> Dict[str, Any]:
     """
     return serialize_for_json(data)
 
+async def setup_websocket_message_routing():
+    """Set up message routing for WebSocket connections"""
+    try:
+        # This will be called from the WebSocket endpoint
+        # to set up the consumer for a specific job
+        pass
+    except Exception as e:
+        logger.error(f"Failed to set up WebSocket message routing: {e}")
+
 # =========================================
 # ENHANCED TRAINING JOB STATUS EVENTS
 # =========================================
diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py
index 03f329e1..2152724d 100644
--- a/services/training/app/services/training_service.py
+++ b/services/training/app/services/training_service.py
@@ -16,6 +16,15 @@
 from app.services.training_orchestrator import TrainingDataOrchestrator
 from app.core.database import get_db_session
 
+from app.services.messaging import (
+    publish_job_progress,
+    publish_data_validation_started,
+    publish_data_validation_completed,
+    publish_job_step_completed,
+    publish_job_completed,
+    publish_job_failed
+)
+
 logger = logging.getLogger(__name__)
 
 class TrainingService:
@@ -61,18 +70,12 @@
 
         logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
 
-        from app.services.messaging import TrainingStatusPublisher
-        status_publisher = TrainingStatusPublisher(job_id, tenant_id)
 
         try:
-            await status_publisher.job_started({
-                "bakery_location": bakery_location,
-                "has_custom_dates": bool(requested_start or requested_end)
-            }, 0)  # Will be updated when we know product count
-
             # Step 1: Prepare training dataset with date alignment and orchestration
             logger.info("Step 1: Preparing and aligning training data")
+            await publish_job_progress(job_id, tenant_id, 0, "Extracting sales data")
 
             training_dataset = await self.orchestrator.prepare_training_data(
                 tenant_id=tenant_id,
                 bakery_location=bakery_location,
@@ -83,6 +86,7 @@
 
             # Step 2: Execute ML training pipeline
             logger.info("Step 2: Starting ML training pipeline")
+            await publish_job_progress(job_id, tenant_id, 35, "Starting ML training pipeline")
             training_results = await self.trainer.train_tenant_models(
                 tenant_id=tenant_id,
                 training_dataset=training_dataset,
@@ -110,12 +114,11 @@
             }
 
             logger.info(f"Training job {job_id} completed successfully")
-            await status_publisher.job_completed(final_result)
+            await publish_job_completed(job_id, tenant_id, final_result)
 
             return TrainingService.create_detailed_training_response(final_result)
         except Exception as e:
             logger.error(f"Training job {job_id} failed: {str(e)}")
-            await status_publisher.job_failed(str(e))
             # Return error response in same detailed format
             final_result = {
                 "job_id": job_id,
@@ -139,7 +142,7 @@
                 "completed_at": datetime.now().isoformat(),
                 "error_message": str(e)
             }
-
+            await publish_job_failed(job_id, tenant_id, str(e), final_result)
             return TrainingService.create_detailed_training_response(final_result)
 
     async def start_single_product_training(