REFACTOR external service and improve websocket training

This commit is contained in:
Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -1,21 +1,18 @@
"""
Training Operations API - BUSINESS logic
Handles training job execution, metrics, and WebSocket live feed
Handles training job execution and metrics
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
import asyncio
import json
import datetime
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
from datetime import datetime, timezone
import uuid
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from datetime import datetime, timezone
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
@@ -23,15 +20,10 @@ from app.schemas.training import (
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started,
training_publisher
from app.services.training_events import (
publish_training_started,
publish_training_completed,
publish_training_failed
)
from app.core.config import settings
@@ -85,6 +77,14 @@ async def start_training_job(
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Publish training.started event immediately so WebSocket clients
# have initial state when they connect
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=0 # Will be updated when actual training starts
)
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -190,12 +190,8 @@ async def execute_training_job_background(
tenant_id=tenant_id
)
# Publish job started event
await publish_job_started(job_id, tenant_id, {
"enhanced_features": True,
"repository_pattern": True,
"job_type": "enhanced_training"
})
# This will be published by the training service itself
# when it starts execution
training_config = {
"job_id": job_id,
@@ -241,16 +237,7 @@ async def execute_training_job_background(
tenant_id=tenant_id
)
# Publish enhanced completion event
await publish_job_completed(
job_id=job_id,
tenant_id=tenant_id,
results={
**result,
"enhanced_features": True,
"repository_integration": True
}
)
# Completion event is published by the training service
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
@@ -276,17 +263,8 @@ async def execute_training_job_background(
job_id=job_id,
status_error=str(status_error))
# Publish enhanced failure event
await publish_job_failed(
job_id=job_id,
tenant_id=tenant_id,
error=str(training_error),
metadata={
"enhanced_features": True,
"repository_pattern": True,
"error_type": type(training_error).__name__
}
)
# Failure event is published by the training service
await publish_training_failed(job_id, tenant_id, str(training_error))
except Exception as background_error:
logger.error("Critical error in enhanced background training job",
@@ -370,373 +348,19 @@ async def start_single_product_training(
)
# ============================================
# WebSocket Live Feed
# ============================================
class ConnectionManager:
"""Manage WebSocket connections for training progress"""
def __init__(self):
self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
# Structure: {job_id: {connection_id: websocket}}
async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
"""Accept WebSocket connection and register it"""
await websocket.accept()
if job_id not in self.active_connections:
self.active_connections[job_id] = {}
self.active_connections[job_id][connection_id] = websocket
logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")
def disconnect(self, job_id: str, connection_id: str):
"""Remove WebSocket connection"""
if job_id in self.active_connections:
self.active_connections[job_id].pop(connection_id, None)
if not self.active_connections[job_id]:
del self.active_connections[job_id]
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job with better error handling"""
if job_id not in self.active_connections:
logger.debug(f"No active connections for job {job_id}")
return
# Send to all connections for this job
disconnected_connections = []
for connection_id, websocket in self.active_connections[job_id].items():
try:
await websocket.send_json(message)
logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id)
# Clean up disconnected connections
for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id)
# Log successful sends
active_count = len(self.active_connections.get(job_id, {}))
if active_count > 0:
logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
# Global connection manager
connection_manager = ConnectionManager()
@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
async def training_progress_websocket(
websocket: WebSocket,
tenant_id: str,
job_id: str
):
"""
WebSocket endpoint for real-time training progress updates
"""
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
# Send immediate connection confirmation to prevent gateway timeout
try:
await websocket.send_json({
"type": "connected",
"job_id": job_id,
"message": "WebSocket connection established",
"timestamp": str(datetime.now())
})
logger.debug(f"Sent connection confirmation for job {job_id}")
except Exception as e:
logger.error(f"Failed to send connection confirmation for job {job_id}: {e}")
consumer_task = None
training_completed = False
try:
# Start RabbitMQ consumer
consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id)
)
last_activity = asyncio.get_event_loop().time()
while not training_completed:
try:
try:
data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
last_activity = asyncio.get_event_loop().time()
# Handle different message types
if data["type"] == "websocket.receive":
if "text" in data:
message_text = data["text"]
if message_text == "ping":
await websocket.send_text("pong")
logger.debug(f"Text ping received from job {job_id}")
elif message_text == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message_text == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
elif "bytes" in data:
await websocket.send_text("pong")
logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
elif data["type"] == "websocket.disconnect":
logger.info(f"WebSocket disconnect message received for job {job_id}")
break
except asyncio.TimeoutError:
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 90:
logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
try:
await websocket.send_json({
"type": "heartbeat",
"job_id": job_id,
"timestamp": str(datetime.now()),
"message": "Training service heartbeat - frontend inactive",
"inactivity_seconds": int(current_time - last_activity)
})
last_activity = current_time
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
else:
logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
continue
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected for job {job_id}")
break
except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}")
if "Cannot call" in str(e) and "disconnect message" in str(e):
logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
break
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
except Exception as e:
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally:
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id)
if consumer_task and not consumer_task.done():
if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
try:
await consumer_task
except asyncio.CancelledError:
logger.info(f"Consumer task cancelled for job {job_id}")
except Exception as e:
logger.error(f"Consumer task error for job {job_id}: {e}")
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job"""
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
try:
# Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}"
async def handle_training_message(message):
"""Handle incoming RabbitMQ messages and forward to WebSocket"""
try:
# Parse the message
body = message.body.decode()
data = json.loads(body)
logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
# Extract event data
event_type = data.get("event_type", "unknown")
event_data = data.get("data", {})
# Only process messages for this specific job
message_job_id = event_data.get("job_id") if event_data else None
if message_job_id != job_id:
logger.debug(f"Ignoring message for different job: {message_job_id}")
await message.ack()
return
# Transform RabbitMQ message to WebSocket message format
websocket_message = {
"type": map_event_type_to_websocket_type(event_type),
"job_id": job_id,
"timestamp": data.get("timestamp"),
"data": event_data
}
logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
# Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message)
# Check if this is a completion message
if event_type in ["training.completed", "training.failed"]:
logger.info(f"Training completion detected for job {job_id}: {event_type}")
# Acknowledge the message
await message.ack()
logger.debug(f"Successfully processed {event_type} for job {job_id}")
except Exception as e:
logger.error(f"Error handling training message for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
await message.nack(requeue=False)
# Check if training_publisher is connected
if not training_publisher.connected:
logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
success = await training_publisher.connect()
if not success:
logger.error(f"Failed to connect training_publisher for job {job_id}")
return
# Subscribe to training events
logger.info(f"Subscribing to training events for job {job_id}")
success = await training_publisher.consume_events(
exchange_name="training.events",
queue_name=queue_name,
routing_key="training.*",
callback=handle_training_message
)
if success:
logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
# Keep the consumer running indefinitely until cancelled
try:
while True:
await asyncio.sleep(10)
logger.debug(f"Consumer heartbeat for job {job_id}")
except asyncio.CancelledError:
logger.info(f"Consumer cancelled for job {job_id}")
raise
except Exception as e:
logger.error(f"Consumer error for job {job_id}: {e}")
raise
else:
logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")
except Exception as e:
logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types"""
mapping = {
"training.started": "started",
"training.progress": "progress",
"training.completed": "completed",
"training.failed": "failed",
"training.cancelled": "cancelled",
"training.step.completed": "step_completed",
"training.product.started": "product_started",
"training.product.completed": "product_completed",
"training.product.failed": "product_failed",
"training.model.trained": "model_trained",
"training.data.validation.started": "validation_started",
"training.data.validation.completed": "validation_completed"
}
return mapping.get(rabbitmq_event_type, "unknown")
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
"""Get current job status from database"""
try:
return {
"job_id": job_id,
"status": "running",
"progress": 0,
"current_step": "Starting...",
"started_at": "2025-07-30T19:00:00Z"
}
except Exception as e:
logger.error(f"Failed to get current job status: {e}")
return None
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "2.0.0",
"version": "3.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations",
"websocket-support"
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}