""" Training Operations API - BUSINESS logic Handles training job execution and metrics """ from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path from typing import Optional, Dict, Any import structlog from datetime import datetime, timezone import uuid import shared.redis_utils from sqlalchemy.ext.asyncio import AsyncSession from shared.routing import RouteBuilder from shared.monitoring.decorators import track_execution_time from shared.monitoring.metrics import get_metrics_collector from shared.database.base import create_database_manager from shared.auth.decorators import get_current_user_dep from shared.auth.access_control import require_user_role, admin_role_required, service_only_access from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction from shared.subscription.plans import ( get_training_job_quota, get_dataset_size_limit ) from app.services.training_service import EnhancedTrainingService from app.schemas.training import ( TrainingJobRequest, SingleProductTrainingRequest, TrainingJobResponse ) from app.utils.time_estimation import ( calculate_initial_estimate, calculate_estimated_completion_time, get_historical_average_estimate ) from app.services.training_events import ( publish_training_started, publish_training_completed, publish_training_failed ) from app.core.config import settings from app.core.database import get_db from app.models import AuditLog logger = structlog.get_logger() route_builder = RouteBuilder('training') router = APIRouter(tags=["training-operations"]) # Initialize audit logger audit_logger = create_audit_logger("training-service", AuditLog) # Redis client for rate limiting _redis_client = None async def get_training_redis_client(): """Get or create Redis client for rate limiting""" global _redis_client if _redis_client is None: # Initialize Redis if not already done try: from app.core.config import settings _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL) except: # Fallback to getting the client directly (if already initialized elsewhere) _redis_client = await shared.redis_utils.get_redis_client() return _redis_client async def get_rate_limiter(): """Dependency for rate limiter""" redis_client = await get_training_redis_client() return create_rate_limiter(redis_client) def get_enhanced_training_service(): """Dependency injection for EnhancedTrainingService""" database_manager = create_database_manager(settings.DATABASE_URL, "training-service") return EnhancedTrainingService(database_manager) @router.post( route_builder.build_base_route("jobs"), response_model=TrainingJobResponse) @require_user_role(['admin', 'owner']) @track_execution_time("enhanced_training_job_duration_seconds", "training-service") async def start_training_job( request: TrainingJobRequest, tenant_id: str = Path(..., description="Tenant ID"), background_tasks: BackgroundTasks = BackgroundTasks(), request_obj: Request = None, current_user: Dict[str, Any] = Depends(get_current_user_dep), enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service), rate_limiter = Depends(get_rate_limiter), db: AsyncSession = Depends(get_db) ): """ Start a new training job for all tenant products (Admin+ only, quota enforced). **RBAC:** Admin or Owner role required **Quotas:** - Starter: 1 training job/day, max 1,000 rows - Professional: 5 training jobs/day, max 10,000 rows - Enterprise: Unlimited jobs, unlimited rows Enhanced immediate response pattern: 1. Validate subscription tier and quotas 2. Validate request with enhanced validation 3. Create job record using repository pattern 4. Return 200 with enhanced job details 5. Execute enhanced training in background with repository tracking Enhanced features: - Repository pattern for data access - Quota enforcement by subscription tier - Audit logging for all operations - Enhanced error handling and logging - Metrics tracking and monitoring - Transactional operations """ metrics = get_metrics_collector(request_obj) # Get subscription tier and enforce quotas tier = current_user.get('subscription_tier', 'starter') # Estimate dataset size (this should come from the request or be calculated) # For now, we'll assume a reasonable estimate estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500 # Initialize variables for later use quota_result = None quota_limit = None try: # Validate dataset size limits await rate_limiter.validate_dataset_size( tenant_id, estimated_dataset_size, tier ) # Check daily training job quota quota_limit = get_training_job_quota(tier) quota_result = await rate_limiter.check_and_increment_quota( tenant_id, "training_jobs", quota_limit, period=86400 # 24 hours ) logger.info("Training job quota check passed", tenant_id=tenant_id, tier=tier, current_usage=quota_result.get('current', 0) if quota_result else 0, limit=quota_limit) except HTTPException: # Quota or validation error - re-raise raise except Exception as quota_error: logger.error("Quota validation failed", error=str(quota_error)) # Continue with job creation but log the error try: # CRITICAL FIX: Check for existing running jobs before starting new one # This prevents duplicate tenant-level training jobs async with enhanced_training_service.database_manager.get_session() as check_session: from app.repositories.training_log_repository import TrainingLogRepository log_repo = TrainingLogRepository(check_session) # Check for active jobs (running or pending) active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id) pending_jobs = await log_repo.get_logs_by_tenant( tenant_id=tenant_id, status="pending", limit=10 ) all_active = active_jobs + pending_jobs if all_active: # Training job already in progress, return existing job info existing_job = all_active[0] logger.info("Training job already in progress, returning existing job", existing_job_id=existing_job.job_id, tenant_id=tenant_id, status=existing_job.status) return TrainingJobResponse( job_id=existing_job.job_id, tenant_id=tenant_id, status=existing_job.status, message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})", created_at=existing_job.created_at or datetime.now(timezone.utc), estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15, training_results={ "total_products": 0, "successful_trainings": 0, "failed_trainings": 0, "products": [], "overall_training_time_seconds": 0.0 }, data_summary=None, completed_at=None, error_details=None, processing_metadata={ "background_task": True, "async_execution": True, "existing_job": True, "deduplication": True } ) # No existing job, proceed with creating new one # Generate enhanced job ID job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}" logger.info("Creating enhanced training job using repository pattern", job_id=job_id, tenant_id=tenant_id) # Record job creation metrics if metrics: metrics.increment_counter("enhanced_training_jobs_created_total") # Calculate intelligent time estimate # We don't know exact product count yet, so use historical average or estimate try: # Try to get historical average for this tenant historical_avg = await get_historical_average_estimate(db, tenant_id) # If no historical data, estimate based on typical product count (10-20 products) estimated_products = 15 # Conservative estimate estimated_duration_minutes = calculate_initial_estimate( total_products=estimated_products, avg_training_time_per_product=historical_avg if historical_avg else 60.0 ) except Exception as est_error: logger.warning("Could not calculate intelligent estimate, using default", error=str(est_error)) estimated_duration_minutes = 15 # Default fallback # Calculate estimated completion time estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes) # Note: training.started event will be published by the trainer with accurate product count # We don't publish here to avoid duplicate events # Add enhanced background task background_tasks.add_task( execute_training_job_background, tenant_id=tenant_id, job_id=job_id, bakery_location=(40.4168, -3.7038), requested_start=request.start_date, requested_end=request.end_date, estimated_duration_minutes=estimated_duration_minutes ) # Return enhanced immediate success response response_data = { "job_id": job_id, "tenant_id": tenant_id, "status": "pending", "message": "Enhanced training job started successfully using repository pattern", "created_at": datetime.now(timezone.utc), "estimated_duration_minutes": estimated_duration_minutes, "training_results": { "total_products": 0, "successful_trainings": 0, "failed_trainings": 0, "products": [], "overall_training_time_seconds": 0.0 }, "data_summary": None, "completed_at": None, "error_details": None, "processing_metadata": { "background_task": True, "async_execution": True, "enhanced_features": True, "repository_pattern": True, "dependency_injection": True } } logger.info("Enhanced training job queued successfully", job_id=job_id, features=["repository-pattern", "dependency-injection", "enhanced-tracking"]) # Log audit event for training job creation try: from app.core.database import database_manager async with database_manager.get_session() as db: await audit_logger.log_event( db_session=db, tenant_id=tenant_id, user_id=current_user["user_id"], action=AuditAction.CREATE.value, resource_type="training_job", resource_id=job_id, severity=AuditSeverity.MEDIUM.value, description=f"Started training job (tier: {tier})", audit_metadata={ "job_id": job_id, "tier": tier, "estimated_dataset_size": estimated_dataset_size, "quota_usage": quota_result.get('current', 0) if quota_result else 0, "quota_limit": quota_limit if quota_limit else "unlimited" }, endpoint="/jobs", method="POST" ) except Exception as audit_error: logger.warning("Failed to log audit event", error=str(audit_error)) return TrainingJobResponse(**response_data) except HTTPException: raise except ValueError as e: if metrics: metrics.increment_counter("enhanced_training_validation_errors_total") logger.error("Enhanced training job validation error", error=str(e), tenant_id=tenant_id) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) ) except Exception as e: if metrics: metrics.increment_counter("enhanced_training_job_errors_total") logger.error("Failed to queue enhanced training job", error=str(e), tenant_id=tenant_id) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to start enhanced training job" ) async def execute_training_job_background( tenant_id: str, job_id: str, bakery_location: tuple, requested_start: Optional[datetime] = None, requested_end: Optional[datetime] = None, estimated_duration_minutes: int = 15 ): """ Enhanced background task that executes the training job using repository pattern. Enhanced features: - Repository pattern for all data operations - Enhanced error handling with structured logging - Transactional operations for data consistency - Comprehensive metrics tracking - Database connection pooling - Enhanced progress reporting """ logger.info("Enhanced background training job started", job_id=job_id, tenant_id=tenant_id, features=["repository-pattern", "enhanced-tracking"]) # Get enhanced training service with dependency injection database_manager = create_database_manager(settings.DATABASE_URL, "training-service") enhanced_training_service = EnhancedTrainingService(database_manager) try: # Create initial training log entry first await enhanced_training_service._update_job_status_repository( job_id=job_id, status="pending", progress=0, current_step="Starting enhanced training job", tenant_id=tenant_id ) # This will be published by the training service itself # when it starts execution training_config = { "job_id": job_id, "tenant_id": tenant_id, "bakery_location": { "latitude": bakery_location[0], "longitude": bakery_location[1] }, "requested_start": requested_start.isoformat() if requested_start else None, "requested_end": requested_end.isoformat() if requested_end else None, "estimated_duration_minutes": estimated_duration_minutes, "background_execution": True, "enhanced_features": True, "repository_pattern": True, "api_version": "enhanced_v1" } # Update job status using repository pattern await enhanced_training_service._update_job_status_repository( job_id=job_id, status="running", progress=0, current_step="Initializing enhanced training pipeline", tenant_id=tenant_id ) # Execute the enhanced training pipeline with repository pattern result = await enhanced_training_service.start_training_job( tenant_id=tenant_id, job_id=job_id, bakery_location=bakery_location, requested_start=requested_start, requested_end=requested_end ) # Note: Final status is already updated by start_training_job() via complete_training_log() # No need for redundant update here - it was causing duplicate log entries # Completion event is published by the training service logger.info("Enhanced background training job completed successfully", job_id=job_id, models_created=result.get('products_trained', 0), features=["repository-pattern", "enhanced-tracking"]) except Exception as training_error: logger.error("Enhanced training pipeline failed", job_id=job_id, error=str(training_error)) try: await enhanced_training_service._update_job_status_repository( job_id=job_id, status="failed", progress=0, current_step="Enhanced training failed", error_message=str(training_error), tenant_id=tenant_id ) except Exception as status_error: logger.error("Failed to update job status after training error", job_id=job_id, status_error=str(status_error)) # Failure event is published by the training service await publish_training_failed(job_id, tenant_id, str(training_error)) finally: logger.info("Enhanced background training job cleanup completed", job_id=job_id) @router.post( route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse) @require_user_role(['admin', 'owner']) @track_execution_time("enhanced_single_product_training_duration_seconds", "training-service") async def start_single_product_training( request: SingleProductTrainingRequest, tenant_id: str = Path(..., description="Tenant ID"), inventory_product_id: str = Path(..., description="Inventory product UUID"), background_tasks: BackgroundTasks = BackgroundTasks(), request_obj: Request = None, current_user: Dict[str, Any] = Depends(get_current_user_dep), enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service) ): """ Start enhanced training for a single product (Admin+ only). **RBAC:** Admin or Owner role required Enhanced features: - Repository pattern for data access - Enhanced error handling and validation - Metrics tracking - Transactional operations - Background execution to prevent blocking """ metrics = get_metrics_collector(request_obj) try: logger.info("Starting enhanced single product training", inventory_product_id=inventory_product_id, tenant_id=tenant_id) # CRITICAL FIX: Check if this product is currently being trained # This prevents duplicate training from rapid-click scenarios async with enhanced_training_service.database_manager.get_session() as check_session: from app.repositories.training_log_repository import TrainingLogRepository log_repo = TrainingLogRepository(check_session) # Check for active jobs for this specific product active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id) pending_jobs = await log_repo.get_logs_by_tenant( tenant_id=tenant_id, status="pending", limit=20 ) all_active = active_jobs + pending_jobs # Filter for jobs that include this specific product product_jobs = [ job for job in all_active if job.config and ( # Single product job for this product job.config.get("product_id") == inventory_product_id or # Tenant-wide job that would include this product job.config.get("job_type") == "tenant_training" ) ] if product_jobs: existing_job = product_jobs[0] logger.warning("Product training already in progress, rejecting duplicate request", existing_job_id=existing_job.job_id, tenant_id=tenant_id, inventory_product_id=inventory_product_id, status=existing_job.status) raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail={ "error": "Product training already in progress", "message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}", "existing_job_id": existing_job.job_id, "status": existing_job.status, "started_at": existing_job.created_at.isoformat() if existing_job.created_at else None } ) # No existing job, proceed with training # Record metrics if metrics: metrics.increment_counter("enhanced_single_product_training_total") # Generate enhanced job ID job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}" # CRITICAL FIX: Add initial training log entry await enhanced_training_service._update_job_status_repository( job_id=job_id, status="pending", progress=0, current_step="Initializing single product training", tenant_id=tenant_id ) # Add enhanced background task for single product training background_tasks.add_task( execute_single_product_training_background, tenant_id=tenant_id, inventory_product_id=inventory_product_id, job_id=job_id, bakery_location=request.bakery_location or (40.4168, -3.7038), database_manager=enhanced_training_service.database_manager ) # Return immediate response with job info response_data = { "job_id": job_id, "tenant_id": tenant_id, "status": "pending", "message": "Enhanced single product training started successfully", "created_at": datetime.now(timezone.utc), "estimated_duration_minutes": 15, # Default estimate for single product "training_results": { "total_products": 1, "successful_trainings": 0, "failed_trainings": 0, "products": [], "overall_training_time_seconds": 0.0 }, "data_summary": None, "completed_at": None, "error_details": None, "processing_metadata": { "background_task": True, "async_execution": True, "enhanced_features": True, "repository_pattern": True, "dependency_injection": True } } logger.info("Enhanced single product training queued successfully", inventory_product_id=inventory_product_id, job_id=job_id) if metrics: metrics.increment_counter("enhanced_single_product_training_queued_total") return TrainingJobResponse(**response_data) except ValueError as e: if metrics: metrics.increment_counter("enhanced_single_product_validation_errors_total") logger.error("Enhanced single product training validation error", error=str(e), inventory_product_id=inventory_product_id) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) ) except Exception as e: if metrics: metrics.increment_counter("enhanced_single_product_training_errors_total") logger.error("Enhanced single product training failed", error=str(e), inventory_product_id=inventory_product_id) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Enhanced single product training failed" ) async def execute_single_product_training_background( tenant_id: str, inventory_product_id: str, job_id: str, bakery_location: tuple, database_manager ): """ Enhanced background task that executes single product training using repository pattern. Uses a separate service instance to avoid session conflicts. """ logger.info("Enhanced background single product training started", job_id=job_id, tenant_id=tenant_id, inventory_product_id=inventory_product_id) # Create a new service instance with a fresh database session to avoid conflicts from app.services.training_service import EnhancedTrainingService fresh_training_service = EnhancedTrainingService(database_manager) try: # Update job status to running await fresh_training_service._update_job_status_repository( job_id=job_id, status="running", progress=0, current_step="Starting single product training", tenant_id=tenant_id ) # Execute the enhanced single product training with repository pattern result = await fresh_training_service.start_single_product_training( tenant_id=tenant_id, inventory_product_id=inventory_product_id, job_id=job_id, bakery_location=bakery_location ) logger.info("Enhanced background single product training completed successfully", job_id=job_id, inventory_product_id=inventory_product_id) except Exception as training_error: logger.error("Enhanced single product training failed", job_id=job_id, inventory_product_id=inventory_product_id, error=str(training_error)) try: await fresh_training_service._update_job_status_repository( job_id=job_id, status="failed", progress=0, current_step="Single product training failed", error_message=str(training_error), tenant_id=tenant_id ) except Exception as status_error: logger.error("Failed to update job status after training error", job_id=job_id, status_error=str(status_error)) finally: logger.info("Enhanced background single product training cleanup completed", job_id=job_id, inventory_product_id=inventory_product_id) @router.get("/health") async def health_check(): """Health check endpoint for the training operations""" return { "status": "healthy", "service": "training-operations", "version": "3.0.0", "features": [ "repository-pattern", "dependency-injection", "enhanced-error-handling", "metrics-tracking", "transactional-operations" ], "timestamp": datetime.now().isoformat() } # ============================================================================ # Tenant Data Deletion Operations (Internal Service Only) # ============================================================================ @router.delete( route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False), response_model=dict ) @service_only_access async def delete_tenant_data( tenant_id: str = Path(..., description="Tenant ID to delete data for"), current_user: dict = Depends(get_current_user_dep) ): """ Delete all training data for a tenant (Internal service only) This endpoint is called by the orchestrator during tenant deletion. It permanently deletes all training-related data including: - Trained models (all versions) - Model artifacts (files and metadata) - Training logs and job history - Model performance metrics - Training job queue entries - Audit logs **WARNING**: This operation is irreversible! **NOTE**: Physical model files (.pkl) should be cleaned up separately Returns: Deletion summary with counts of deleted records """ from app.services.tenant_deletion_service import TrainingTenantDeletionService from app.core.config import settings try: logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id) db_manager = create_database_manager(settings.DATABASE_URL, "training") async with db_manager.get_session() as session: deletion_service = TrainingTenantDeletionService(session) result = await deletion_service.safe_delete_tenant_data(tenant_id) if not result.success: raise HTTPException( status_code=500, detail=f"Tenant data deletion failed: {', '.join(result.errors)}" ) return { "message": "Tenant data deletion completed successfully", "note": "Physical model files should be cleaned up separately from storage", "summary": result.to_dict() } except HTTPException: raise except Exception as e: logger.error("training.tenant_deletion.api_error", tenant_id=tenant_id, error=str(e), exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to delete tenant data: {str(e)}" ) @router.get( route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False), response_model=dict ) @service_only_access async def preview_tenant_data_deletion( tenant_id: str = Path(..., description="Tenant ID to preview deletion for"), current_user: dict = Depends(get_current_user_dep) ): """ Preview what data would be deleted for a tenant (dry-run) This endpoint shows counts of all data that would be deleted without actually deleting anything. Useful for: - Confirming deletion scope before execution - Auditing and compliance - Troubleshooting Returns: Dictionary with entity names and their counts """ from app.services.tenant_deletion_service import TrainingTenantDeletionService from app.core.config import settings try: logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id) db_manager = create_database_manager(settings.DATABASE_URL, "training") async with db_manager.get_session() as session: deletion_service = TrainingTenantDeletionService(session) preview = await deletion_service.get_tenant_data_preview(tenant_id) total_records = sum(preview.values()) return { "tenant_id": tenant_id, "service": "training", "preview": preview, "total_records": total_records, "note": "Physical model files (.pkl, metadata) are not counted here", "warning": "These records will be permanently deleted and cannot be recovered" } except Exception as e: logger.error("training.tenant_deletion.preview_error", tenant_id=tenant_id, error=str(e), exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to preview tenant data deletion: {str(e)}" )