367 lines
13 KiB
Python
367 lines
13 KiB
Python
"""
Training Operations API - BUSINESS logic
Handles training job execution and metrics
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
|
|
from typing import Optional, Dict, Any
|
|
import structlog
|
|
from datetime import datetime, timezone
|
|
import uuid
|
|
|
|
from shared.routing import RouteBuilder
|
|
from shared.monitoring.decorators import track_execution_time
|
|
from shared.monitoring.metrics import get_metrics_collector
|
|
from shared.database.base import create_database_manager
|
|
|
|
from app.services.training_service import EnhancedTrainingService
|
|
from app.schemas.training import (
|
|
TrainingJobRequest,
|
|
SingleProductTrainingRequest,
|
|
TrainingJobResponse
|
|
)
|
|
from app.services.training_events import (
|
|
publish_training_started,
|
|
publish_training_completed,
|
|
publish_training_failed
|
|
)
|
|
from app.core.config import settings
|
|
|
|
# Structured logger shared by every handler in this module
logger = structlog.get_logger()
# Produces the URL paths passed to the route decorators below
route_builder = RouteBuilder('training')

# Router on which this module's endpoints are registered
router = APIRouter(tags=["training-operations"])
|
|
|
|
def get_enhanced_training_service():
    """Build the EnhancedTrainingService handed to endpoints via ``Depends``.

    A database manager is created from ``settings.DATABASE_URL`` under the
    "training-service" name and passed to the service constructor.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
|
|
|
|
|
|
@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Queue a training job covering all of a tenant's products.

    The handler does not run the training itself. In order, it:
    1. Generates a job ID of the form ``enhanced_training_<tenant_id>_<8 hex chars>``
    2. Increments the job-created counter when a metrics collector is available
    3. Publishes the training-started event (with ``total_products=0``)
    4. Schedules ``execute_training_job_background`` on ``background_tasks``
    5. Returns a ``TrainingJobResponse`` with status "pending" and empty results

    Error mapping:
    - HTTPException raised inside the handler is re-raised untouched
    - ValueError becomes HTTP 400 with the error text as detail
    - Any other exception becomes HTTP 500 with a generic detail
    """
    collector = get_metrics_collector(request_obj)

    try:
        # Job IDs carry the tenant plus a short random suffix
        suffix = uuid.uuid4().hex[:8]
        job_id = f"enhanced_training_{tenant_id}_{suffix}"

        logger.info(
            "Creating enhanced training job using repository pattern",
            job_id=job_id,
            tenant_id=tenant_id,
        )

        # Count the job as created
        if collector:
            collector.increment_counter("enhanced_training_jobs_created_total")

        # Emit the started event right away so WebSocket clients that
        # connect early already have an initial state; the product count
        # is a placeholder until the actual training starts
        await publish_training_started(
            job_id=job_id,
            tenant_id=tenant_id,
            total_products=0,
        )

        # Hand the real work to the background runner
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date,
        )

        # Nothing has been trained yet, so every result figure starts at zero
        empty_results = dict(
            total_products=0,
            successful_trainings=0,
            failed_trainings=0,
            products=[],
            overall_training_time_seconds=0.0,
        )
        execution_flags = dict(
            background_task=True,
            async_execution=True,
            enhanced_features=True,
            repository_pattern=True,
            dependency_injection=True,
        )
        payload = dict(
            job_id=job_id,
            tenant_id=tenant_id,
            status="pending",
            message="Enhanced training job started successfully using repository pattern",
            created_at=datetime.now(timezone.utc),
            estimated_duration_minutes=18,
            training_results=empty_results,
            data_summary=None,
            completed_at=None,
            error_details=None,
            processing_metadata=execution_flags,
        )

        logger.info(
            "Enhanced training job queued successfully",
            job_id=job_id,
            features=["repository-pattern", "dependency-injection", "enhanced-tracking"],
        )

        return TrainingJobResponse(**payload)

    except HTTPException:
        # Already an HTTP error: pass it through unchanged
        raise
    except ValueError as validation_exc:
        if collector:
            collector.increment_counter("enhanced_training_validation_errors_total")
        logger.error(
            "Enhanced training job validation error",
            error=str(validation_exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(validation_exc),
        )
    except Exception as unexpected_exc:
        if collector:
            collector.increment_counter("enhanced_training_job_errors_total")
        logger.error(
            "Failed to queue enhanced training job",
            error=str(unexpected_exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job",
        )
|
|
|
|
|
|
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Run a queued training job and record its outcome.

    Steps:
    - Build an EnhancedTrainingService on a newly created database manager
    - Mark the job "pending", then "running", through the service
    - Await ``start_training_job`` on the service
    - Mark the job "completed", storing the returned result

    Failure handling:
    - If any step of the pipeline raises, the job is marked "failed" (best
      effort: a failure of that status update is only logged) and the
      training-failed event is published.
    - Anything raised outside the pipeline (building the service, publishing
      the failure event) is logged as a critical error and is not re-raised.

    Args:
        tenant_id: Tenant the job belongs to.
        job_id: Identifier of the job being executed.
        bakery_location: Location value forwarded unchanged to the service.
        requested_start: Optional start of the requested date range.
        requested_end: Optional end of the requested date range.
    """
    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    try:
        # Built inside the outer try so a construction failure is logged
        # by the critical-error handler instead of escaping the task
        database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
        enhanced_training_service = EnhancedTrainingService(database_manager)

        try:
            # Create the initial training log entry first
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="pending",
                progress=0,
                current_step="Starting enhanced training job",
                tenant_id=tenant_id
            )

            # Flag the job as running before the pipeline starts
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="running",
                progress=0,
                current_step="Initializing enhanced training pipeline",
                tenant_id=tenant_id
            )

            # Execute the training pipeline
            result = await enhanced_training_service.start_training_job(
                tenant_id=tenant_id,
                job_id=job_id,
                bakery_location=bakery_location,
                requested_start=requested_start,
                requested_end=requested_end
            )

            # Record the final status together with the pipeline result
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="completed",
                progress=100,
                current_step="Enhanced training completed successfully",
                results=result,
                tenant_id=tenant_id
            )

            logger.info("Enhanced background training job completed successfully",
                        job_id=job_id,
                        models_created=result.get('products_trained', 0),
                        features=["repository-pattern", "enhanced-tracking"])

        except Exception as training_error:
            logger.error("Enhanced training pipeline failed",
                         job_id=job_id,
                         error=str(training_error))

            # Best effort: a failing status update must not prevent the
            # failure event below from being published
            try:
                await enhanced_training_service._update_job_status_repository(
                    job_id=job_id,
                    status="failed",
                    progress=0,
                    current_step="Enhanced training failed",
                    error_message=str(training_error),
                    tenant_id=tenant_id
                )
            except Exception as status_error:
                logger.error("Failed to update job status after training error",
                             job_id=job_id,
                             status_error=str(status_error))

            # Publish the training-failed event for this job
            await publish_training_failed(job_id, tenant_id, str(training_error))

    except Exception as background_error:
        # Reached for failures outside the pipeline: service construction
        # or an error raised while handling a pipeline failure
        logger.error("Critical error in enhanced background training job",
                     job_id=job_id,
                     error=str(background_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)
|
|
|
|
|
|
@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Train a single product and return the outcome in the same request.

    The call is awaited inline: the handler generates a job ID of the form
    ``enhanced_single_<tenant_id>_<product_id>_<8 hex chars>``, delegates to
    ``start_single_product_training`` on the injected service, and builds the
    response from the returned mapping.

    Args:
        request: Request body; ``bakery_location`` is used when truthy,
            otherwise the fallback coordinates (40.4168, -3.7038) are sent.
        tenant_id: Tenant ID taken from the URL path.
        inventory_product_id: Product identifier taken from the URL path.
        request_obj: Incoming request, used to look up the metrics collector.
        enhanced_training_service: Service instance supplied by ``Depends``.

    Returns:
        TrainingJobResponse built from the service result.

    Raises:
        HTTPException: re-raised unchanged when raised inside the handler;
            400 for a ValueError (detail is the error text); 500 with a
            generic detail for any other exception.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Job IDs carry tenant, product and a short random suffix
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to the training service
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of letting the generic
        # handler below rewrite them as 500s (same as start_training_job)
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        ) from e
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """Report a static health payload for the training operations router.

    Returns:
        dict with the fixed status, service name, version and feature list,
        plus the current time as a timezone-aware UTC ISO-8601 string.
    """
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "3.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        # UTC-aware, consistent with the created_at timestamp produced by
        # start_training_job; a naive local time is ambiguous across hosts
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
|