Files
bakery-ia/services/training/app/api/training_operations.py
2025-10-29 06:58:05 +01:00

506 lines
19 KiB
Python

"""
Training Operations API - BUSINESS logic
Handles training job execution and metrics
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from sqlalchemy.ext.asyncio import AsyncSession
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
publish_training_failed
)
from app.core.config import settings
from app.core.database import get_db
from app.models import AuditLog
# Module-level structlog logger used by the handlers in this file
logger = structlog.get_logger()
# Builds the route paths passed to the @router.post decorators below
route_builder = RouteBuilder('training')
# Router on which the endpoints of this module are registered
router = APIRouter(tags=["training-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("training-service", AuditLog)
# Redis client for rate limiting; stays None until
# get_training_redis_client() populates it on first use
_redis_client = None
async def get_training_redis_client():
    """Get or create the Redis client used for rate limiting.

    The client is cached in the module-level ``_redis_client`` so that later
    calls reuse it. On first use the shared Redis helper is initialised with
    ``settings.REDIS_URL``; if that fails, the function falls back to
    ``shared.redis_utils.get_redis_client()`` (covers the case where Redis was
    already initialised elsewhere).

    Returns:
        The client object handed back by ``shared.redis_utils``.

    Raises:
        Exception: whatever the fallback ``get_redis_client()`` call raises
            when no client can be obtained at all.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    try:
        _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
    except Exception as init_error:
        # Deliberately `except Exception` rather than a bare `except:` so that
        # task cancellation and interpreter-exit signals are not swallowed.
        logger.warning(
            "Redis initialization failed, falling back to existing client",
            error=str(init_error),
        )
        _redis_client = None

    if _redis_client is None:
        # Fallback to getting the client directly (if already initialized
        # elsewhere). NOTE(review): this branch also covers initialize_redis()
        # returning None instead of a client - confirm its return contract.
        _redis_client = await shared.redis_utils.get_redis_client()
    return _redis_client
async def get_rate_limiter():
    """FastAPI dependency: build a rate limiter on top of the training Redis client."""
    return create_rate_limiter(await get_training_redis_client())
def get_enhanced_training_service():
    """FastAPI dependency: construct an EnhancedTrainingService.

    A database manager is created for ``settings.DATABASE_URL`` under the
    "training-service" name and handed to the service constructor.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
    rate_limiter = Depends(get_rate_limiter),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all tenant products (Admin+ only, quota enforced).

    **RBAC:** Admin or Owner role required

    **Quotas:**
    - Starter: 1 training job/day, max 1,000 rows
    - Professional: 5 training jobs/day, max 10,000 rows
    - Enterprise: Unlimited jobs, unlimited rows

    Enhanced immediate response pattern:
    1. Validate subscription tier and quotas
    2. Validate request with enhanced validation
    3. Create job record using repository pattern
    4. Return 200 with enhanced job details
    5. Execute enhanced training in background with repository tracking

    Enhanced features:
    - Repository pattern for data access
    - Quota enforcement by subscription tier
    - Audit logging for all operations
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations

    Returns:
        A ``TrainingJobResponse`` in "pending" state; the training itself runs
        in ``execute_training_job_background`` after the response is sent.

    Raises:
        HTTPException: re-raised unchanged when raised by the quota/size
            checks or while queuing; 400 for a ``ValueError``; 500 for any
            other failure while queuing the job.
    """
    metrics = get_metrics_collector(request_obj)

    # Get subscription tier and enforce quotas
    tier = current_user.get('subscription_tier', 'starter')

    # Estimate dataset size (this should come from the request or be calculated)
    # For now, fall back to a fixed estimate when the request carries none
    estimated_dataset_size = getattr(request, 'estimated_rows', 500)

    # Filled in by the quota check; stay None if that check errors out
    quota_result = None
    quota_limit = None

    try:
        # Validate dataset size limits
        await rate_limiter.validate_dataset_size(
            tenant_id, estimated_dataset_size, tier
        )

        # Check daily training job quota
        quota_limit = get_training_job_quota(tier)
        quota_result = await rate_limiter.check_and_increment_quota(
            tenant_id,
            "training_jobs",
            quota_limit,
            period=86400  # 24 hours
        )

        logger.info("Training job quota check passed",
                    tenant_id=tenant_id,
                    tier=tier,
                    current_usage=quota_result.get('current', 0) if quota_result else 0,
                    limit=quota_limit)
    except HTTPException:
        # Quota or validation error - re-raise
        raise
    except Exception as quota_error:
        # Fail open: an error in the quota machinery itself must not block
        # job creation, so it is only logged.
        logger.error("Quota validation failed", error=str(quota_error))

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        # Calculate intelligent time estimate
        # We don't know exact product count yet, so use historical average or estimate
        try:
            # Try to get historical average for this tenant
            historical_avg = await get_historical_average_estimate(db, tenant_id)

            # If no historical data, estimate based on typical product count (10-20 products)
            estimated_products = 15  # Conservative estimate
            estimated_duration_minutes = calculate_initial_estimate(
                total_products=estimated_products,
                avg_training_time_per_product=historical_avg if historical_avg else 60.0
            )
        except Exception as est_error:
            logger.warning("Could not calculate intelligent estimate, using default",
                           error=str(est_error))
            estimated_duration_minutes = 15  # Default fallback

        # Calculate estimated completion time
        estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

        # Publish training.started event immediately so WebSocket clients
        # have initial state when they connect
        await publish_training_started(
            job_id=job_id,
            tenant_id=tenant_id,
            total_products=0,  # Will be updated when actual training starts
            estimated_duration_minutes=estimated_duration_minutes,
            estimated_completion_time=estimated_completion_time.isoformat()
        )

        # Add enhanced background task
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date,
            estimated_duration_minutes=estimated_duration_minutes
        )

        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": estimated_duration_minutes,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        # Log audit event for training job creation.
        # The request-scoped session injected via Depends(get_db) is used
        # directly: calling next(get_db()) here would either fail (if get_db
        # is an async generator) or open a second session that nothing closes.
        try:
            await audit_logger.log_event(
                db_session=db,
                tenant_id=tenant_id,
                user_id=current_user["user_id"],
                action=AuditAction.CREATE.value,
                resource_type="training_job",
                resource_id=job_id,
                severity=AuditSeverity.MEDIUM.value,
                description=f"Started training job (tier: {tier})",
                metadata={
                    "job_id": job_id,
                    "tier": tier,
                    "estimated_dataset_size": estimated_dataset_size,
                    "quota_usage": quota_result.get('current', 0) if quota_result else 0,
                    "quota_limit": quota_limit if quota_limit else "unlimited"
                },
                endpoint="/jobs",
                method="POST"
            )
        except Exception as audit_error:
            # Auditing is best-effort; a failure must not fail the request
            logger.warning("Failed to log audit event", error=str(audit_error))

        return TrainingJobResponse(**response_data)

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        )
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None,
    estimated_duration_minutes: int = 15
):
    """
    Enhanced background task that executes the training job using repository pattern.

    Steps, as performed below:
    1. Record the job as "pending", then "running", through the training
       service's ``_update_job_status_repository``.
    2. Await ``EnhancedTrainingService.start_training_job``.
    3. On any failure, record the job as "failed" (best effort) and publish a
       training-failed event.

    Args:
        tenant_id: Tenant the job belongs to.
        job_id: Identifier of the job being executed.
        bakery_location: Passed through unchanged to the training service.
        requested_start: Optional start of the requested training window.
        requested_end: Optional end of the requested training window.
        estimated_duration_minutes: Accepted for interface compatibility with
            the caller; not consumed by this function.

    Returns:
        None. Exceptions are logged and not re-raised, so a failed job never
        propagates an error out of the background task.
    """
    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    # Get enhanced training service with dependency injection
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        try:
            # Create initial training log entry first
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="pending",
                progress=0,
                current_step="Starting enhanced training job",
                tenant_id=tenant_id
            )

            # Update job status using repository pattern
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="running",
                progress=0,
                current_step="Initializing enhanced training pipeline",
                tenant_id=tenant_id
            )

            # Execute the enhanced training pipeline with repository pattern
            result = await enhanced_training_service.start_training_job(
                tenant_id=tenant_id,
                job_id=job_id,
                bakery_location=bakery_location,
                requested_start=requested_start,
                requested_end=requested_end
            )

            # Note: Final status is already updated by start_training_job() via complete_training_log()
            # No need for redundant update here - it was causing duplicate log entries
            # Completion event is published by the training service
            logger.info("Enhanced background training job completed successfully",
                        job_id=job_id,
                        models_created=result.get('products_trained', 0),
                        features=["repository-pattern", "enhanced-tracking"])

        except Exception as training_error:
            logger.error("Enhanced training pipeline failed",
                         job_id=job_id,
                         error=str(training_error))

            # Best effort: a failing status update must not prevent the
            # failure event below from being published.
            try:
                await enhanced_training_service._update_job_status_repository(
                    job_id=job_id,
                    status="failed",
                    progress=0,
                    current_step="Enhanced training failed",
                    error_message=str(training_error),
                    tenant_id=tenant_id
                )
            except Exception as status_error:
                logger.error("Failed to update job status after training error",
                             job_id=job_id,
                             status_error=str(status_error))

            # Publish the failure event from here
            await publish_training_failed(job_id, tenant_id, str(training_error))

    except Exception as background_error:
        # Reached only when the failure handling above itself raises
        # (e.g. publishing the failure event fails).
        logger.error("Critical error in enhanced background training job",
                     job_id=job_id,
                     error=str(background_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)
@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product (Admin+ only).

    **RBAC:** Admin or Owner role required

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations

    Unlike the all-products endpoint, the training service call is awaited
    inside the request and its result is returned directly.

    Returns:
        A ``TrainingJobResponse`` built from the training service result.

    Raises:
        HTTPException: passed through unchanged when raised downstream; 400
            for a ``ValueError``; 500 for any other failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service; the coordinates are the
        # same default location the all-products endpoint passes along
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except HTTPException:
        # Keep the status code and detail chosen downstream instead of
        # rewriting it into a generic 500 (matches start_training_job)
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )
@router.get("/health")
async def health_check():
    """Health check endpoint for the training operations.

    Returns:
        A static status payload plus the current time as a timezone-aware
        UTC ISO-8601 string, consistent with the UTC timestamps used by the
        other handlers in this module.
    """
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "3.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        # Aware UTC timestamp: a naive local time is ambiguous to consumers
        "timestamp": datetime.now(timezone.utc).isoformat()
    }