# bakery-ia/services/training/app/api/training_operations.py
"""
Training Operations API - BUSINESS logic
Handles training job execution and metrics
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from sqlalchemy.ext.asyncio import AsyncSession
from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required, service_only_access
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
get_training_job_quota,
get_dataset_size_limit
)
from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest,
TrainingJobResponse
)
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
from app.services.training_events import (
publish_training_started,
publish_training_completed,
publish_training_failed
)
from app.core.config import settings
from app.core.database import get_db
from app.models import AuditLog
logger = structlog.get_logger()
route_builder = RouteBuilder('training')
router = APIRouter(tags=["training-operations"])
# Initialize audit logger
audit_logger = create_audit_logger("training-service", AuditLog)
# Redis client for rate limiting
_redis_client = None
async def get_training_redis_client():
"""Get or create Redis client for rate limiting"""
global _redis_client
if _redis_client is None:
# Initialize Redis if not already done
try:
from app.core.config import settings
_redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except Exception:
            # Fall back to the shared client if Redis was already initialized elsewhere
_redis_client = await shared.redis_utils.get_redis_client()
return _redis_client
async def get_rate_limiter():
"""Dependency for rate limiter"""
redis_client = await get_training_redis_client()
return create_rate_limiter(redis_client)
def get_enhanced_training_service():
"""Dependency injection for EnhancedTrainingService"""
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
return EnhancedTrainingService(database_manager)
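
# ---------------------------------------------------------------------------
# Hedged sketch: the providers above are plain FastAPI dependencies, so tests
# can swap them out via the standard `dependency_overrides` mapping. The
# `application` argument and the stub limiter below are illustrative
# assumptions, not part of this module.
# ---------------------------------------------------------------------------
def _example_override_dependencies_for_tests(application) -> None:
    """Hedged example: replace the Redis-backed rate limiter with an in-memory stub."""

    class _StubRateLimiter:
        async def validate_dataset_size(self, tenant_id: str, size: int, tier: str) -> None:
            return None  # never reject a dataset in tests

        async def check_and_increment_quota(self, tenant_id: str, name: str, limit, period: int) -> dict:
            return {"current": 1}  # pretend this is the first job of the day

    async def _stub_rate_limiter():
        return _StubRateLimiter()

    # `application` is assumed to be the FastAPI app that includes this router.
    application.dependency_overrides[get_rate_limiter] = _stub_rate_limiter
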
@router.post(
route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
request: TrainingJobRequest,
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
rate_limiter = Depends(get_rate_limiter),
db: AsyncSession = Depends(get_db)
):
"""
Start a new training job for all tenant products (Admin+ only, quota enforced).
**RBAC:** Admin or Owner role required
**Quotas:**
- Starter: 1 training job/day, max 1,000 rows
- Professional: 5 training jobs/day, max 10,000 rows
- Enterprise: Unlimited jobs, unlimited rows
Enhanced immediate response pattern:
1. Validate subscription tier and quotas
2. Validate request with enhanced validation
3. Create job record using repository pattern
4. Return 200 with enhanced job details
5. Execute enhanced training in background with repository tracking
Enhanced features:
- Repository pattern for data access
- Quota enforcement by subscription tier
- Audit logging for all operations
- Enhanced error handling and logging
- Metrics tracking and monitoring
- Transactional operations
"""
metrics = get_metrics_collector(request_obj)
# Get subscription tier and enforce quotas
tier = current_user.get('subscription_tier', 'starter')
    # Estimate the dataset size; ideally this comes from the request, otherwise
    # fall back to a conservative default.
    estimated_dataset_size = request.estimated_rows if hasattr(request, 'estimated_rows') else 500
# Initialize variables for later use
quota_result = None
quota_limit = None
try:
# Validate dataset size limits
await rate_limiter.validate_dataset_size(
tenant_id, estimated_dataset_size, tier
)
# Check daily training job quota
quota_limit = get_training_job_quota(tier)
quota_result = await rate_limiter.check_and_increment_quota(
tenant_id,
"training_jobs",
quota_limit,
period=86400 # 24 hours
)
logger.info("Training job quota check passed",
tenant_id=tenant_id,
tier=tier,
current_usage=quota_result.get('current', 0) if quota_result else 0,
limit=quota_limit)
except HTTPException:
# Quota or validation error - re-raise
raise
except Exception as quota_error:
logger.error("Quota validation failed", error=str(quota_error))
        # Fail open: infrastructure errors during quota checks should not block job creation
try:
        # CRITICAL FIX: check for existing running jobs before starting a new one;
        # this prevents duplicate tenant-level training jobs.
async with enhanced_training_service.database_manager.get_session() as check_session:
from app.repositories.training_log_repository import TrainingLogRepository
log_repo = TrainingLogRepository(check_session)
# Check for active jobs (running or pending)
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
pending_jobs = await log_repo.get_logs_by_tenant(
tenant_id=tenant_id,
status="pending",
limit=10
)
all_active = active_jobs + pending_jobs
if all_active:
# Training job already in progress, return existing job info
existing_job = all_active[0]
logger.info("Training job already in progress, returning existing job",
existing_job_id=existing_job.job_id,
tenant_id=tenant_id,
status=existing_job.status)
return TrainingJobResponse(
job_id=existing_job.job_id,
tenant_id=tenant_id,
status=existing_job.status,
message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})",
created_at=existing_job.created_at or datetime.now(timezone.utc),
estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15,
training_results={
"total_products": 0,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
data_summary=None,
completed_at=None,
error_details=None,
processing_metadata={
"background_task": True,
"async_execution": True,
"existing_job": True,
"deduplication": True
}
)
# No existing job, proceed with creating new one
# Generate enhanced job ID
job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Creating enhanced training job using repository pattern",
job_id=job_id,
tenant_id=tenant_id)
# Record job creation metrics
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
        # Calculate an intelligent time estimate. The exact product count is not known
        # yet, so use the tenant's historical average, or a default estimate if none exists.
try:
# Try to get historical average for this tenant
historical_avg = await get_historical_average_estimate(db, tenant_id)
# If no historical data, estimate based on typical product count (10-20 products)
estimated_products = 15 # Conservative estimate
estimated_duration_minutes = calculate_initial_estimate(
total_products=estimated_products,
avg_training_time_per_product=historical_avg if historical_avg else 60.0
)
except Exception as est_error:
logger.warning("Could not calculate intelligent estimate, using default",
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback
# Calculate estimated completion time
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Note: training.started event will be published by the trainer with accurate product count
# We don't publish here to avoid duplicate events
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
tenant_id=tenant_id,
job_id=job_id,
            bakery_location=(40.4168, -3.7038),  # default bakery location (Madrid)
requested_start=request.start_date,
requested_end=request.end_date,
estimated_duration_minutes=estimated_duration_minutes
)
# Return enhanced immediate success response
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced training job started successfully using repository pattern",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": estimated_duration_minutes,
"training_results": {
"total_products": 0,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced training job queued successfully",
job_id=job_id,
features=["repository-pattern", "dependency-injection", "enhanced-tracking"])
# Log audit event for training job creation
try:
from app.core.database import database_manager
            async with database_manager.get_session() as audit_session:
                await audit_logger.log_event(
                    db_session=audit_session,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
audit_metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))
return TrainingJobResponse(**response_data)
except HTTPException:
raise
except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_training_validation_errors_total")
logger.error("Enhanced training job validation error",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_training_job_errors_total")
logger.error("Failed to queue enhanced training job",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start enhanced training job"
)
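
# ---------------------------------------------------------------------------
# Illustrative client-side sketch of the immediate-response pattern implemented
# above. The route prefix, request payload and bearer-token auth are assumptions
# (the real path comes from RouteBuilder('training') and the gateway config);
# only the response fields mirror what this handler actually returns.
# ---------------------------------------------------------------------------
async def _example_start_tenant_training(base_url: str, tenant_id: str, token: str) -> dict:
    """Hedged example: start a tenant-wide training job and return the immediate response."""
    import httpx  # local import so this illustration adds no hard module dependency

    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            f"/api/v1/tenants/{tenant_id}/training/jobs",  # assumed prefix
            json={"start_date": None, "end_date": None},   # assumed optional fields of TrainingJobRequest
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        # The handler answers immediately with job_id, status "pending" (or the details of an
        # already-active job) and estimated_duration_minutes; training runs in the background.
        return resp.json()
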
async def execute_training_job_background(
tenant_id: str,
job_id: str,
bakery_location: tuple,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
estimated_duration_minutes: int = 15
):
"""
Enhanced background task that executes the training job using repository pattern.
Enhanced features:
- Repository pattern for all data operations
- Enhanced error handling with structured logging
- Transactional operations for data consistency
- Comprehensive metrics tracking
- Database connection pooling
- Enhanced progress reporting
"""
logger.info("Enhanced background training job started",
job_id=job_id,
tenant_id=tenant_id,
features=["repository-pattern", "enhanced-tracking"])
# Get enhanced training service with dependency injection
database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
enhanced_training_service = EnhancedTrainingService(database_manager)
try:
# Create initial training log entry first
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Starting enhanced training job",
tenant_id=tenant_id
)
        # The training.started event is published by the training service itself
        # once execution begins, so it is intentionally not published here.
training_config = {
"job_id": job_id,
"tenant_id": tenant_id,
"bakery_location": {
"latitude": bakery_location[0],
"longitude": bakery_location[1]
},
"requested_start": requested_start.isoformat() if requested_start else None,
"requested_end": requested_end.isoformat() if requested_end else None,
"estimated_duration_minutes": estimated_duration_minutes,
"background_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"api_version": "enhanced_v1"
}
# Update job status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Initializing enhanced training pipeline",
tenant_id=tenant_id
)
# Execute the enhanced training pipeline with repository pattern
result = await enhanced_training_service.start_training_job(
tenant_id=tenant_id,
job_id=job_id,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end
)
# Note: Final status is already updated by start_training_job() via complete_training_log()
# No need for redundant update here - it was causing duplicate log entries
# Completion event is published by the training service
logger.info("Enhanced background training job completed successfully",
job_id=job_id,
models_created=result.get('products_trained', 0),
features=["repository-pattern", "enhanced-tracking"])
except Exception as training_error:
logger.error("Enhanced training pipeline failed",
job_id=job_id,
error=str(training_error))
try:
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Enhanced training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
        # Publish the failure event so downstream consumers are notified of the error
await publish_training_failed(job_id, tenant_id, str(training_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)
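
# ---------------------------------------------------------------------------
# Hedged sketch of the job-status lifecycle driven by the background task above.
# "pending", "running" and "failed" are the literal strings written via
# _update_job_status_repository() in this module; the terminal "completed"
# state is assumed to be written by EnhancedTrainingService through
# complete_training_log(), per the note in the task body.
# ---------------------------------------------------------------------------
_JOB_STATUS_TRANSITIONS = {
    "pending": {"running", "failed"},
    "running": {"completed", "failed"},
    "completed": set(),
    "failed": set(),
}
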
@router.post(
route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
request: SingleProductTrainingRequest,
tenant_id: str = Path(..., description="Tenant ID"),
inventory_product_id: str = Path(..., description="Inventory product UUID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
request_obj: Request = None,
current_user: Dict[str, Any] = Depends(get_current_user_dep),
enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
"""
Start enhanced training for a single product (Admin+ only).
**RBAC:** Admin or Owner role required
Enhanced features:
- Repository pattern for data access
- Enhanced error handling and validation
- Metrics tracking
- Transactional operations
- Background execution to prevent blocking
"""
metrics = get_metrics_collector(request_obj)
try:
logger.info("Starting enhanced single product training",
inventory_product_id=inventory_product_id,
tenant_id=tenant_id)
        # CRITICAL FIX: check whether this product is already being trained;
        # this prevents duplicate training from rapid repeated requests (e.g. double clicks).
async with enhanced_training_service.database_manager.get_session() as check_session:
from app.repositories.training_log_repository import TrainingLogRepository
log_repo = TrainingLogRepository(check_session)
# Check for active jobs for this specific product
active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
pending_jobs = await log_repo.get_logs_by_tenant(
tenant_id=tenant_id,
status="pending",
limit=20
)
all_active = active_jobs + pending_jobs
# Filter for jobs that include this specific product
product_jobs = [
job for job in all_active
if job.config and (
# Single product job for this product
job.config.get("product_id") == inventory_product_id or
# Tenant-wide job that would include this product
job.config.get("job_type") == "tenant_training"
)
]
if product_jobs:
existing_job = product_jobs[0]
logger.warning("Product training already in progress, rejecting duplicate request",
existing_job_id=existing_job.job_id,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
status=existing_job.status)
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"error": "Product training already in progress",
"message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}",
"existing_job_id": existing_job.job_id,
"status": existing_job.status,
"started_at": existing_job.created_at.isoformat() if existing_job.created_at else None
}
)
# No existing job, proceed with training
# Record metrics
if metrics:
metrics.increment_counter("enhanced_single_product_training_total")
# Generate enhanced job ID
job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
        # CRITICAL FIX: create the initial training log entry so the job is visible immediately
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="pending",
progress=0,
current_step="Initializing single product training",
tenant_id=tenant_id
)
# Add enhanced background task for single product training
background_tasks.add_task(
execute_single_product_training_background,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038),  # default: Madrid
database_manager=enhanced_training_service.database_manager
)
# Return immediate response with job info
response_data = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "pending",
"message": "Enhanced single product training started successfully",
"created_at": datetime.now(timezone.utc),
"estimated_duration_minutes": 15, # Default estimate for single product
"training_results": {
"total_products": 1,
"successful_trainings": 0,
"failed_trainings": 0,
"products": [],
"overall_training_time_seconds": 0.0
},
"data_summary": None,
"completed_at": None,
"error_details": None,
"processing_metadata": {
"background_task": True,
"async_execution": True,
"enhanced_features": True,
"repository_pattern": True,
"dependency_injection": True
}
}
logger.info("Enhanced single product training queued successfully",
inventory_product_id=inventory_product_id,
job_id=job_id)
if metrics:
metrics.increment_counter("enhanced_single_product_training_queued_total")
return TrainingJobResponse(**response_data)
    except HTTPException:
        # Propagate explicit HTTP errors (e.g. the 409 duplicate-training conflict) unchanged
        raise
    except ValueError as e:
if metrics:
metrics.increment_counter("enhanced_single_product_validation_errors_total")
logger.error("Enhanced single product training validation error",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
if metrics:
metrics.increment_counter("enhanced_single_product_training_errors_total")
logger.error("Enhanced single product training failed",
error=str(e),
inventory_product_id=inventory_product_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Enhanced single product training failed"
)
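
# ---------------------------------------------------------------------------
# Illustrative client-side sketch of the duplicate-training guard above: a 409
# response carries the existing job's id and status inside `detail`. Route
# prefix, payload and auth are assumptions, as in the tenant-wide example
# earlier in this module.
# ---------------------------------------------------------------------------
async def _example_start_single_product_training(
    base_url: str, tenant_id: str, inventory_product_id: str, token: str
) -> dict:
    """Hedged example: start single-product training, reusing the in-flight job on 409."""
    import httpx  # local import; illustration only

    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            f"/api/v1/tenants/{tenant_id}/training/products/{inventory_product_id}",  # assumed prefix
            json={"bakery_location": None},  # the handler falls back to its default location
            headers={"Authorization": f"Bearer {token}"},
        )
        if resp.status_code == status.HTTP_409_CONFLICT:
            # Already training: detail includes existing_job_id, status and started_at.
            return resp.json()["detail"]
        resp.raise_for_status()
        return resp.json()
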
async def execute_single_product_training_background(
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple,
database_manager
):
"""
Enhanced background task that executes single product training using repository pattern.
Uses a separate service instance to avoid session conflicts.
"""
logger.info("Enhanced background single product training started",
job_id=job_id,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id)
# Create a new service instance with a fresh database session to avoid conflicts
from app.services.training_service import EnhancedTrainingService
fresh_training_service = EnhancedTrainingService(database_manager)
try:
# Update job status to running
await fresh_training_service._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Starting single product training",
tenant_id=tenant_id
)
# Execute the enhanced single product training with repository pattern
result = await fresh_training_service.start_single_product_training(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id,
bakery_location=bakery_location
)
logger.info("Enhanced background single product training completed successfully",
job_id=job_id,
inventory_product_id=inventory_product_id)
except Exception as training_error:
logger.error("Enhanced single product training failed",
job_id=job_id,
inventory_product_id=inventory_product_id,
error=str(training_error))
try:
await fresh_training_service._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Single product training failed",
error_message=str(training_error),
tenant_id=tenant_id
)
except Exception as status_error:
logger.error("Failed to update job status after training error",
job_id=job_id,
status_error=str(status_error))
finally:
logger.info("Enhanced background single product training cleanup completed",
job_id=job_id,
inventory_product_id=inventory_product_id)
@router.get("/health")
async def health_check():
"""Health check endpoint for the training operations"""
return {
"status": "healthy",
"service": "training-operations",
"version": "3.0.0",
"features": [
"repository-pattern",
"dependency-injection",
"enhanced-error-handling",
"metrics-tracking",
"transactional-operations"
],
"timestamp": datetime.now().isoformat()
}
# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================
@router.delete(
route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
response_model=dict
)
@service_only_access
async def delete_tenant_data(
tenant_id: str = Path(..., description="Tenant ID to delete data for"),
current_user: dict = Depends(get_current_user_dep)
):
"""
Delete all training data for a tenant (Internal service only)
This endpoint is called by the orchestrator during tenant deletion.
It permanently deletes all training-related data including:
- Trained models (all versions)
- Model artifacts (files and metadata)
- Training logs and job history
- Model performance metrics
- Training job queue entries
- Audit logs
**WARNING**: This operation is irreversible!
**NOTE**: Physical model files (.pkl) should be cleaned up separately
Returns:
Deletion summary with counts of deleted records
"""
from app.services.tenant_deletion_service import TrainingTenantDeletionService
from app.core.config import settings
try:
logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id)
db_manager = create_database_manager(settings.DATABASE_URL, "training")
async with db_manager.get_session() as session:
deletion_service = TrainingTenantDeletionService(session)
result = await deletion_service.safe_delete_tenant_data(tenant_id)
if not result.success:
raise HTTPException(
status_code=500,
detail=f"Tenant data deletion failed: {', '.join(result.errors)}"
)
return {
"message": "Tenant data deletion completed successfully",
"note": "Physical model files should be cleaned up separately from storage",
"summary": result.to_dict()
}
except HTTPException:
raise
except Exception as e:
logger.error("training.tenant_deletion.api_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Failed to delete tenant data: {str(e)}"
)
@router.get(
route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
response_model=dict
)
@service_only_access
async def preview_tenant_data_deletion(
tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
current_user: dict = Depends(get_current_user_dep)
):
"""
Preview what data would be deleted for a tenant (dry-run)
This endpoint shows counts of all data that would be deleted
without actually deleting anything. Useful for:
- Confirming deletion scope before execution
- Auditing and compliance
- Troubleshooting
Returns:
Dictionary with entity names and their counts
"""
from app.services.tenant_deletion_service import TrainingTenantDeletionService
from app.core.config import settings
try:
logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id)
db_manager = create_database_manager(settings.DATABASE_URL, "training")
async with db_manager.get_session() as session:
deletion_service = TrainingTenantDeletionService(session)
preview = await deletion_service.get_tenant_data_preview(tenant_id)
total_records = sum(preview.values())
return {
"tenant_id": tenant_id,
"service": "training",
"preview": preview,
"total_records": total_records,
"note": "Physical model files (.pkl, metadata) are not counted here",
"warning": "These records will be permanently deleted and cannot be recovered"
}
except Exception as e:
logger.error("training.tenant_deletion.preview_error",
tenant_id=tenant_id,
error=str(e),
exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Failed to preview tenant data deletion: {str(e)}"
)
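
# ---------------------------------------------------------------------------
# Hedged sketch of the intended service-to-service flow for the two endpoints
# above: dry-run the deletion scope first, then perform the irreversible delete.
# The route prefixes and the service credential shown here are assumptions;
# @service_only_access defines what is actually accepted.
# ---------------------------------------------------------------------------
async def _example_tenant_deletion_flow(base_url: str, tenant_id: str, service_token: str) -> dict:
    """Hedged example: preview a tenant's training data, then delete it."""
    import httpx  # local import; illustration only

    headers = {"Authorization": f"Bearer {service_token}"}  # assumed service credential
    async with httpx.AsyncClient(base_url=base_url) as client:
        preview = await client.get(
            f"/api/v1/training/tenant/{tenant_id}/deletion-preview",  # assumed prefix
            headers=headers,
        )
        preview.raise_for_status()
        scope = preview.json()
        if scope["total_records"] == 0:
            return scope  # nothing to delete

        deleted = await client.delete(
            f"/api/v1/training/tenant/{tenant_id}",  # assumed prefix
            headers=headers,
        )
        deleted.raise_for_status()
        return deleted.json()  # message, note, and per-entity deletion summary
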