# bakery-ia/services/training/app/api/training_operations.py
"""
Training Operations API - business logic
Handles training job execution, metrics, and WebSocket live feed
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path, WebSocket, WebSocketDisconnect
from typing import List, Optional, Dict, Any
import structlog
import asyncio
import json
from shared.auth.access_control import require_user_role, admin_role_required, analytics_tier_required
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from datetime import datetime, timezone
import uuid
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
    TrainingJobRequest,
    SingleProductTrainingRequest,
    TrainingJobResponse
)
from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_completed,
    publish_job_failed,
    publish_job_started,
    training_publisher
)
from app.core.config import settings

logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
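

# NOTE: FastAPI calls this provider once per request via Depends(), so each
# request constructs its own database manager. If pooled connections should be
# shared across requests, a module-level singleton would be an alternative.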
def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)


@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Optional[Request] = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start a new training job for all tenant products using repository pattern.
Enhanced immediate response pattern:
1. Validate request with enhanced validation
2. Create job record using repository pattern
3. Return 200 with enhanced job details
4. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
    metrics = get_metrics_collector(request_obj)
    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)
        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")
        # Add enhanced background task
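        # NOTE: bakery_location is hard-coded to Madrid (40.4168, -3.7038) below;
        # presumably this should eventually come from tenant configuration.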
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date
        )
        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 18,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
return TrainingJobResponse(**response_data)
except HTTPException:
raise
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start enhanced training job"
)
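

# The background task runs outside the request scope, so it cannot use the
# Depends() provider above; it constructs its own service instance instead.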
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Background task that executes the training job using the repository pattern.

    Features:
    - Repository pattern for all data operations
    - Structured error handling and logging
    - Transactional operations for data consistency
    - Comprehensive metrics tracking
    - Database connection pooling
    - Progress reporting
    """
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
    try:
        # Inner try: pipeline failures are handled and published as job failures;
        # the outer try catches errors in that failure handling itself
        try:
            # Create initial training log entry first
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="pending",
                progress=0,
                current_step="Starting enhanced training job",
                tenant_id=tenant_id
            )
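            # NOTE: _update_job_status_repository is private by convention (leading
            # underscore); a public status-update method on EnhancedTrainingService
            # would make these cross-module calls cleaner.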
            # Publish job started event
            await publish_job_started(job_id, tenant_id, {
                "enhanced_features": True,
                "repository_pattern": True,
                "job_type": "enhanced_training"
            })
            training_config = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "bakery_location": {
                    "latitude": bakery_location[0],
                    "longitude": bakery_location[1]
                },
                "requested_start": requested_start.isoformat() if requested_start else None,
                "requested_end": requested_end.isoformat() if requested_end else None,
                "estimated_duration_minutes": 18,
                "background_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "api_version": "enhanced_v1"
            }
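            # NOTE: training_config is assembled but never passed to the training
            # service below; as written it is dead code.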
            # Update job status using repository pattern
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="running",
                progress=0,
                current_step="Initializing enhanced training pipeline",
                tenant_id=tenant_id
            )
            # Execute the enhanced training pipeline with repository pattern
            result = await enhanced_training_service.start_training_job(
                tenant_id=tenant_id,
                job_id=job_id,
                bakery_location=bakery_location,
                requested_start=requested_start,
                requested_end=requested_end
            )
            # Update final status using repository pattern
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="completed",
                progress=100,
                current_step="Enhanced training completed successfully",
                results=result,
                tenant_id=tenant_id
            )
            # Publish enhanced completion event
            await publish_job_completed(
                job_id=job_id,
                tenant_id=tenant_id,
                results={
                    **result,
                    "enhanced_features": True,
                    "repository_integration": True
                }
            )
            logger.info("Enhanced background training job completed successfully",
                        job_id=job_id,
                        models_created=result.get('products_trained', 0),
                        features=["repository-pattern", "enhanced-tracking"])
        except Exception as training_error:
            logger.error("Enhanced training pipeline failed",
                         job_id=job_id,
                         error=str(training_error))
            try:
                await enhanced_training_service._update_job_status_repository(
                    job_id=job_id,
                    status="failed",
                    progress=0,
                    current_step="Enhanced training failed",
                    error_message=str(training_error),
                    tenant_id=tenant_id
                )
            except Exception as status_error:
                logger.error("Failed to update job status after training error",
                             job_id=job_id,
                             status_error=str(status_error))
            # Publish enhanced failure event
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error=str(training_error),
                metadata={
                    "enhanced_features": True,
                    "repository_pattern": True,
                    "error_type": type(training_error).__name__
                }
            )
    except Exception as background_error:
        # Reached only if the failure handling above raises (e.g. publishing fails)
        logger.error("Critical error in enhanced background training job",
                     job_id=job_id,
                     error=str(background_error))
    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)


@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Optional[Request] = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product using repository pattern.
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
"""
    metrics = get_metrics_collector(request_obj)
    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)
        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")
        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
        # Delegate to enhanced training service
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")
        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)
        return TrainingJobResponse(**result)
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )

# ============================================
# WebSocket Live Feed
# ============================================
class ConnectionManager:
    """Manage WebSocket connections for training progress"""

    def __init__(self):
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
        # Structure: {job_id: {connection_id: websocket}}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept WebSocket connection and register it"""
        await websocket.accept()
        if job_id not in self.active_connections:
            self.active_connections[job_id] = {}
        self.active_connections[job_id][connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Remove WebSocket connection"""
        if job_id in self.active_connections:
            self.active_connections[job_id].pop(connection_id, None)
            if not self.active_connections[job_id]:
                del self.active_connections[job_id]
        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send message to all connections for a specific job with error handling"""
        if job_id not in self.active_connections:
            logger.debug(f"No active connections for job {job_id}")
            return
        # Send to all connections for this job; snapshot the items so the dict
        # cannot change size while we await inside the loop
        disconnected_connections = []
        for connection_id, websocket in list(self.active_connections[job_id].items()):
            try:
                await websocket.send_json(message)
                logger.debug(f"Sent {message.get('type', 'unknown')} to connection {connection_id}")
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                disconnected_connections.append(connection_id)
        # Clean up disconnected connections
        for connection_id in disconnected_connections:
            self.disconnect(job_id, connection_id)
        # Log successful sends
        active_count = len(self.active_connections.get(job_id, {}))
        if active_count > 0:
            logger.info(f"Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")

# Global connection manager
connection_manager = ConnectionManager()
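
# NOTE: this manager is process-local; with multiple workers or replicas,
# progress events still reach every client because each connection subscribes
# to the broker via its own queue (see setup_rabbitmq_consumer_for_job below).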


@router.websocket(route_builder.build_nested_resource_route('jobs', 'job_id', 'live'))
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
"""
WebSocket endpoint for real-time training progress updates
"""
# Validate token from query parameters
token = websocket.query_params.get("token")
if not token:
logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
await websocket.close(code=1008, reason="Authentication token required")
return
# Validate the token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload:
logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
await websocket.close(code=1008, reason="Invalid authentication token")
return
# Verify user has access to this tenant
user_id = payload.get('user_id')
if not user_id:
logger.warning(f"WebSocket connection rejected - no user_id in token for job {job_id}")
await websocket.close(code=1008, reason="Invalid token payload")
return
logger.info(f"WebSocket authenticated for user {payload.get('email', 'unknown')} on job {job_id}")
except Exception as e:
logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
await websocket.close(code=1008, reason="Token validation failed")
return
connection_id = f"{tenant_id}_{user_id}_{id(websocket)}"
await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}, user {user_id}")
    consumer_task = None
    training_completed = False
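    # NOTE: training_completed is never set to True anywhere in this handler as
    # written; the loop below exits only via break (client close, disconnect, or
    # error). Completion events reach the client through the RabbitMQ consumer.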
    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )
        last_activity = asyncio.get_event_loop().time()
        while not training_completed:
            try:
                try:
                    data = await asyncio.wait_for(websocket.receive(), timeout=60.0)
                    last_activity = asyncio.get_event_loop().time()
                    # Handle different message types
                    if data["type"] == "websocket.receive":
                        if "text" in data:
                            message_text = data["text"]
                            if message_text == "ping":
                                await websocket.send_text("pong")
                                logger.debug(f"Text ping received from job {job_id}")
                            elif message_text == "get_status":
                                current_status = await get_current_job_status(job_id, tenant_id)
                                if current_status:
                                    await websocket.send_json({
                                        "type": "current_status",
                                        "job_id": job_id,
                                        "data": current_status
                                    })
                            elif message_text == "close":
                                logger.info(f"Client requested connection close for job {job_id}")
                                break
                        elif "bytes" in data:
                            await websocket.send_text("pong")
                            logger.debug(f"Binary ping received for job {job_id}, responding with text pong")
                    elif data["type"] == "websocket.disconnect":
                        logger.info(f"WebSocket disconnect message received for job {job_id}")
                        break
                except asyncio.TimeoutError:
                    current_time = asyncio.get_event_loop().time()
                    if current_time - last_activity > 90:
                        logger.warning(f"No frontend activity for 90s on job {job_id}, sending training service heartbeat")
                        try:
                            await websocket.send_json({
                                "type": "heartbeat",
                                "job_id": job_id,
                                "timestamp": str(datetime.now()),
                                "message": "Training service heartbeat - frontend inactive",
                                "inactivity_seconds": int(current_time - last_activity)
                            })
                            last_activity = current_time
                        except Exception as e:
                            logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                            break
                    else:
                        logger.debug(f"Normal 60s timeout for job {job_id}, continuing (last activity: {int(current_time - last_activity)}s ago)")
                        continue
            except WebSocketDisconnect:
                logger.info(f"WebSocket client disconnected for job {job_id}")
                break
            except Exception as e:
                logger.error(f"WebSocket error for job {job_id}: {e}")
                if "Cannot call" in str(e) and "disconnect message" in str(e):
                    logger.error("FastAPI WebSocket disconnect error - connection already closed")
                    break
                await asyncio.sleep(1)
        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
    except Exception as e:
        logger.error(f"Critical WebSocket error for job {job_id}: {e}")
    finally:
        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
        connection_manager.disconnect(job_id, connection_id)
        if consumer_task and not consumer_task.done():
            if training_completed:
                logger.info(f"Training completed, cancelling consumer for job {job_id}")
            else:
                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
            # Cancel unconditionally: the consumer loops forever, so awaiting it
            # without cancelling first would block this handler indefinitely
            consumer_task.cancel()
            try:
                await consumer_task
            except asyncio.CancelledError:
                logger.info(f"Consumer task cancelled for job {job_id}")
            except Exception as e:
                logger.error(f"Consumer task error for job {job_id}: {e}")


async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Set up RabbitMQ consumer to listen for training events for a specific job"""
    logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
    try:
        # Create a unique queue for this WebSocket connection
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        async def handle_training_message(message):
            """Handle incoming RabbitMQ messages and forward to WebSocket"""
            try:
                # Parse the message
                body = message.body.decode()
                data = json.loads(body)
                logger.debug(f"Received message for job {job_id}: {data.get('event_type', 'unknown')}")
                # Extract event data
                event_type = data.get("event_type", "unknown")
                event_data = data.get("data", {})
                # Only process messages for this specific job
                message_job_id = event_data.get("job_id") if event_data else None
                if message_job_id != job_id:
                    logger.debug(f"Ignoring message for different job: {message_job_id}")
                    await message.ack()
                    return
                # Transform RabbitMQ message to WebSocket message format
                websocket_message = {
                    "type": map_event_type_to_websocket_type(event_type),
                    "job_id": job_id,
                    "timestamp": data.get("timestamp"),
                    "data": event_data
                }
                logger.info(f"Forwarding {event_type} message to WebSocket clients for job {job_id}")
                # Send to all WebSocket connections for this job
                await connection_manager.send_to_job(job_id, websocket_message)
                # Check if this is a completion message
                if event_type in ["training.completed", "training.failed"]:
                    logger.info(f"Training completion detected for job {job_id}: {event_type}")
                # Acknowledge the message
                await message.ack()
                logger.debug(f"Successfully processed {event_type} for job {job_id}")
            except Exception as e:
                logger.error(f"Error handling training message for job {job_id}: {e}")
                import traceback
                logger.error(f"Traceback: {traceback.format_exc()}")
                await message.nack(requeue=False)

        # Check if training_publisher is connected
        if not training_publisher.connected:
            logger.warning(f"Training publisher not connected for job {job_id}, attempting to connect...")
            success = await training_publisher.connect()
            if not success:
                logger.error(f"Failed to connect training_publisher for job {job_id}")
                return
        # Subscribe to training events
        logger.info(f"Subscribing to training events for job {job_id}")
        success = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",
            callback=handle_training_message
        )
        if success:
            logger.info(f"Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
            # Keep the consumer running indefinitely until cancelled
            try:
                while True:
                    await asyncio.sleep(10)
                    logger.debug(f"Consumer heartbeat for job {job_id}")
            except asyncio.CancelledError:
                logger.info(f"Consumer cancelled for job {job_id}")
                raise
            except Exception as e:
                logger.error(f"Consumer error for job {job_id}: {e}")
                raise
        else:
            logger.error(f"Failed to set up RabbitMQ consumer for job {job_id}")
    except Exception as e:
        logger.error(f"Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")


def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Map RabbitMQ event types to WebSocket message types"""
    mapping = {
        "training.started": "started",
        "training.progress": "progress",
        "training.completed": "completed",
        "training.failed": "failed",
        "training.cancelled": "cancelled",
        "training.step.completed": "step_completed",
        "training.product.started": "product_started",
        "training.product.completed": "product_completed",
        "training.product.failed": "product_failed",
        "training.model.trained": "model_trained",
        "training.data.validation.started": "validation_started",
        "training.data.validation.completed": "validation_completed"
    }
    return mapping.get(rabbitmq_event_type, "unknown")
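

# Unmapped event types fall through as "unknown" but are still forwarded to
# clients by handle_training_message, so new backend event types degrade
# gracefully instead of being dropped.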


async def get_current_job_status(job_id: str, tenant_id: str) -> Optional[Dict[str, Any]]:
    """Get current job status.

    TODO: placeholder implementation - returns static data instead of querying
    the database.
    """
    try:
        return {
            "job_id": job_id,
            "status": "running",
            "progress": 0,
            "current_step": "Starting...",
            "started_at": "2025-07-30T19:00:00Z"
        }
    except Exception as e:
        logger.error(f"Failed to get current job status: {e}")
        return None
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "2.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations",
"websocket-support"
],
"timestamp": datetime.now().isoformat()
}