"""
Training Operations API - BUSINESS logic
Handles training job execution and metrics
"""

import uuid
from contextlib import closing
from datetime import datetime, timezone
from typing import Optional, Dict, Any

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path

import shared.redis_utils
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
    get_training_job_quota,
    get_dataset_size_limit
)

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
    TrainingJobRequest,
    SingleProductTrainingRequest,
    TrainingJobResponse
)
from app.utils.time_estimation import (
    calculate_initial_estimate,
    calculate_estimated_completion_time,
    get_historical_average_estimate
)
from app.services.training_events import (
    publish_training_started,
    publish_training_completed,
    publish_training_failed
)
from app.core.config import settings

logger = structlog.get_logger()
route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-operations"])

# Initialize audit logger
audit_logger = create_audit_logger("training-service")

# Redis client for rate limiting; created lazily by get_training_redis_client().
_redis_client = None

# Fallback (latitude, longitude) used when no bakery location is supplied.
# These coordinates appear to be central Madrid — NOTE(review): confirm this
# default is intended for every tenant.
_DEFAULT_BAKERY_LOCATION = (40.4168, -3.7038)

# Dataset size assumed when the request does not carry `estimated_rows`.
_DEFAULT_ESTIMATED_DATASET_ROWS = 500

# The real product count is unknown when a job is queued, so the initial time
# estimate assumes a typical catalogue size (conservative estimate).
_ESTIMATED_PRODUCT_COUNT = 15

# Per-product training time passed to calculate_initial_estimate() when the
# tenant has no historical average (same units as that helper expects).
_DEFAULT_TRAINING_TIME_PER_PRODUCT = 60.0

# Duration (minutes) reported when the estimate cannot be computed at all.
_DEFAULT_ESTIMATED_DURATION_MINUTES = 15


async def get_training_redis_client():
    """Return the Redis client used for rate limiting, creating it on first use.

    First tries ``shared.redis_utils.initialize_redis(settings.REDIS_URL)``. If that
    fails, falls back to ``shared.redis_utils.get_redis_client()``, which is meant to
    return a client that was already initialized elsewhere. The result is cached in
    the module-level ``_redis_client``.

    Raises:
        Whatever ``get_redis_client()`` raises if the fallback also fails.
    """
    global _redis_client
    if _redis_client is None:
        try:
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except Exception as init_error:
            # Narrowed from a bare `except:` so task cancellation and interpreter
            # exits are not swallowed; the failure is now logged before falling back.
            logger.warning("Redis initialization failed, falling back to existing client",
                           error=str(init_error))
            _redis_client = await shared.redis_utils.get_redis_client()
    return _redis_client


async def get_rate_limiter():
    """FastAPI dependency: build a rate limiter backed by the training Redis client."""
    redis_client = await get_training_redis_client()
    return create_rate_limiter(redis_client)


def get_enhanced_training_service():
    """FastAPI dependency: build an ``EnhancedTrainingService`` for the current request.

    NOTE(review): a new database manager is created on every call; if
    ``create_database_manager`` builds its own connection pool, consider sharing one.
    """
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)


def _estimate_training_duration_minutes(tenant_id: str):
    """Return an initial duration estimate (minutes) for a full training job.

    Uses the tenant's historical per-product average when available, otherwise
    ``_DEFAULT_TRAINING_TIME_PER_PRODUCT``, scaled by ``_ESTIMATED_PRODUCT_COUNT``.
    Any failure (including DB access) is logged and
    ``_DEFAULT_ESTIMATED_DURATION_MINUTES`` is returned instead.
    """
    try:
        from app.core.database import get_db

        # closing() finalizes the get_db() generator so its session cleanup runs
        # deterministically instead of whenever the generator is garbage-collected.
        with closing(get_db()) as db_session_gen:
            db = next(db_session_gen)
            historical_avg = get_historical_average_estimate(db, tenant_id)

        return calculate_initial_estimate(
            total_products=_ESTIMATED_PRODUCT_COUNT,
            avg_training_time_per_product=(
                historical_avg if historical_avg else _DEFAULT_TRAINING_TIME_PER_PRODUCT
            ),
        )
    except Exception as est_error:
        logger.warning("Could not calculate intelligent estimate, using default",
                       error=str(est_error))
        return _DEFAULT_ESTIMATED_DURATION_MINUTES


async def _audit_training_job_created(
    *,
    tenant_id: str,
    current_user: Dict[str, Any],
    job_id: str,
    tier: str,
    estimated_dataset_size: Any,
    quota_result: Any,
    quota_limit: Any,
) -> None:
    """Best-effort audit log entry for a newly queued training job.

    Every failure is logged as a warning and never propagates. That includes a missing
    ``current_user["user_id"]`` and DB errors, so auditing cannot fail the request.
    """
    try:
        from app.core.database import get_db

        # Keep the session open across the awaited log call, then finalize it.
        with closing(get_db()) as db_session_gen:
            db = next(db_session_gen)
            await audit_logger.log_event(
                db_session=db,
                tenant_id=tenant_id,
                user_id=current_user["user_id"],
                action=AuditAction.CREATE.value,
                resource_type="training_job",
                resource_id=job_id,
                severity=AuditSeverity.MEDIUM.value,
                description=f"Started training job (tier: {tier})",
                metadata={
                    "job_id": job_id,
                    "tier": tier,
                    "estimated_dataset_size": estimated_dataset_size,
                    "quota_usage": quota_result.get('current', 0) if quota_result else 0,
                    "quota_limit": quota_limit if quota_limit else "unlimited"
                },
                endpoint="/jobs",
                method="POST"
            )
    except Exception as audit_error:
        logger.warning("Failed to log audit event", error=str(audit_error))


@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
    rate_limiter = Depends(get_rate_limiter)
):
    """
    Start a new training job for all tenant products (Admin+ only, quota enforced).

    **RBAC:** Admin or Owner role required

    **Quotas:**
    - Starter: 1 training job/day, max 1,000 rows
    - Professional: 5 training jobs/day, max 10,000 rows
    - Enterprise: Unlimited jobs, unlimited rows

    Flow:
    1. Validate dataset size and the daily job quota for the subscription tier
       (HTTPException from the rate limiter is re-raised; any other quota
       error is logged and the job proceeds — fail open).
    2. Generate a job ID and an initial duration estimate.
    3. Publish a ``training.started`` event so WebSocket clients have initial state.
    4. Queue ``execute_training_job_background`` as a background task.
    5. Write a best-effort audit entry and return the pending job immediately.

    Raises:
        HTTPException: re-raised from quota/validation checks; 400 for
            ``ValueError`` during job creation; 500 for any other failure.
    """
    metrics = get_metrics_collector(request_obj)

    # Get subscription tier and enforce quotas
    tier = current_user.get('subscription_tier', 'starter')

    # The request may lack `estimated_rows` or carry it as None; fall back to a
    # default in both cases so the rate limiter never receives None.
    requested_rows = getattr(request, 'estimated_rows', None)
    estimated_dataset_size = (
        requested_rows if requested_rows is not None else _DEFAULT_ESTIMATED_DATASET_ROWS
    )

    # Populated by the quota check; also reported in the audit entry.
    quota_result = None
    quota_limit = None

    try:
        # Validate dataset size limits
        await rate_limiter.validate_dataset_size(
            tenant_id, estimated_dataset_size, tier
        )

        # Check daily training job quota
        quota_limit = get_training_job_quota(tier)
        quota_result = await rate_limiter.check_and_increment_quota(
            tenant_id,
            "training_jobs",
            quota_limit,
            period=86400  # 24 hours
        )

        logger.info("Training job quota check passed",
                    tenant_id=tenant_id,
                    tier=tier,
                    current_usage=quota_result.get('current', 0) if quota_result else 0,
                    limit=quota_limit)

    except HTTPException:
        # Quota or validation error - re-raise
        raise
    except Exception as quota_error:
        # Deliberate fail-open: an unavailable quota backend must not block
        # training, so log and continue with job creation.
        logger.error("Quota validation failed", error=str(quota_error))

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        # Product count is unknown until training starts, so estimate from history.
        estimated_duration_minutes = _estimate_training_duration_minutes(tenant_id)
        estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

        # Publish training.started event immediately so WebSocket clients
        # have initial state when they connect
        await publish_training_started(
            job_id=job_id,
            tenant_id=tenant_id,
            total_products=0,  # Will be updated when actual training starts
            estimated_duration_minutes=estimated_duration_minutes,
            estimated_completion_time=estimated_completion_time.isoformat()
        )

        # Add enhanced background task
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=_DEFAULT_BAKERY_LOCATION,
            requested_start=request.start_date,
            requested_end=request.end_date,
            estimated_duration_minutes=estimated_duration_minutes
        )

        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": estimated_duration_minutes,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        # Log audit event for training job creation (never raises)
        await _audit_training_job_created(
            tenant_id=tenant_id,
            current_user=current_user,
            job_id=job_id,
            tier=tier,
            estimated_dataset_size=estimated_dataset_size,
            quota_result=quota_result,
            quota_limit=quota_limit,
        )

        return TrainingJobResponse(**response_data)

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        )


async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None,
    estimated_duration_minutes: int = 15
):
    """
    Background task that runs a full training job for a tenant.

    Marks the job "pending", then "running", and then delegates to
    ``EnhancedTrainingService.start_training_job``. That call records the final
    status itself. If the pipeline fails, the job is marked "failed" and a
    training-failed event is published. Both follow-up steps are best-effort:
    their errors are logged, not raised.

    Args:
        tenant_id: Tenant whose products are trained.
        job_id: Job identifier created by ``start_training_job``.
        bakery_location: ``(latitude, longitude)`` pair forwarded to the service.
        requested_start: Optional start of the requested training window.
        requested_end: Optional end of the requested training window.
        estimated_duration_minutes: Initial estimate computed at queue time; only
            logged here.
    """
    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                estimated_duration_minutes=estimated_duration_minutes,
                features=["repository-pattern", "enhanced-tracking"])

    # Get enhanced training service with dependency injection
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        # Create initial training log entry first
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Starting enhanced training job",
            tenant_id=tenant_id
        )

        # Update job status using repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Initializing enhanced training pipeline",
            tenant_id=tenant_id
        )

        # Execute the enhanced training pipeline with repository pattern
        result = await enhanced_training_service.start_training_job(
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=bakery_location,
            requested_start=requested_start,
            requested_end=requested_end
        )

        # Final status is already recorded by start_training_job() (an extra update
        # here previously caused duplicate log entries). Guard the summary lookup so a
        # non-mapping result cannot turn a successful run into a "failed" status.
        products_trained = result.get('products_trained', 0) if hasattr(result, 'get') else 0

        logger.info("Enhanced background training job completed successfully",
                    job_id=job_id,
                    models_created=products_trained,
                    features=["repository-pattern", "enhanced-tracking"])

    except Exception as training_error:
        logger.error("Enhanced training pipeline failed",
                     job_id=job_id,
                     error=str(training_error))

        try:
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Enhanced training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

        # NOTE(review): the training service may also publish a failure event;
        # confirm this does not produce a duplicate.
        try:
            await publish_training_failed(job_id, tenant_id, str(training_error))
        except Exception as background_error:
            # Previously this handler was a second `except Exception` on the outer
            # try and could never run; an error here escaped the background task.
            logger.error("Critical error in enhanced background training job",
                         job_id=job_id,
                         error=str(background_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)


@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product (Admin+ only).

    **RBAC:** Admin or Owner role required

    The training runs in-request (awaited), unlike the full-tenant job endpoint.
    ``request.bakery_location`` is used when provided, otherwise
    ``_DEFAULT_BAKERY_LOCATION``.

    Raises:
        HTTPException: re-raised unchanged if raised by the service; 400 for
            ``ValueError``; 500 for any other failure.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to enhanced training service
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or _DEFAULT_BAKERY_LOCATION
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except HTTPException:
        # Preserve intentional HTTP errors instead of rewriting them as 500s,
        # matching start_training_job.
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )


@router.get("/health")
async def health_check():
    """Health check endpoint for the training operations.

    Returns a static status payload with a UTC ISO-8601 timestamp (timezone-aware,
    consistent with the other timestamps produced by this module).
    """
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "3.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        "timestamp": datetime.now(timezone.utc).isoformat()
    }