"""
|
|
Training Operations API - BUSINESS logic
|
|
Handles training job execution and metrics
|
|
"""

from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
from typing import Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
import shared.redis_utils
from sqlalchemy.ext.asyncio import AsyncSession

from shared.routing import RouteBuilder
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import require_user_role, admin_role_required, service_only_access
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import (
    get_training_job_quota,
    get_dataset_size_limit
)

from app.services.training_service import EnhancedTrainingService
from app.schemas.training import (
    TrainingJobRequest,
    SingleProductTrainingRequest,
    TrainingJobResponse
)
from app.utils.time_estimation import (
    calculate_initial_estimate,
    calculate_estimated_completion_time,
    get_historical_average_estimate
)
from app.services.training_events import (
    publish_training_started,
    publish_training_completed,
    publish_training_failed
)
from app.core.config import settings
from app.core.database import get_db
from app.models import AuditLog

logger = structlog.get_logger()
route_builder = RouteBuilder('training')

router = APIRouter(tags=["training-operations"])
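
# NOTE: RouteBuilder presumably prefixes routes with a tenant-scoped path
# (something like /tenants/{tenant_id}/training/...); the exact prefix is an
# assumption here -- it is determined by shared.routing.RouteBuilder, and the
# handlers below bind {tenant_id} via Path(...).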

# Initialize audit logger
audit_logger = create_audit_logger("training-service", AuditLog)

# Redis client for rate limiting
_redis_client = None


async def get_training_redis_client():
    """Get or create a Redis client for rate limiting."""
    global _redis_client
    if _redis_client is None:
        # Initialize Redis if not already done
        try:
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except Exception:
            # Fall back to getting the client directly (if already initialized elsewhere)
            _redis_client = await shared.redis_utils.get_redis_client()
    return _redis_client
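
# NOTE: the lazy singleton above is not guarded by a lock, so two concurrent
# first calls may both run the initialization path. That is harmless only if
# initialize_redis() is idempotent -- an assumption about shared.redis_utils.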


async def get_rate_limiter():
    """Dependency for rate limiter"""
    redis_client = await get_training_redis_client()
    return create_rate_limiter(redis_client)
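
# NOTE: the limiter from shared.security.create_rate_limiter is used below both
# for daily job quotas (check_and_increment_quota) and for dataset-size checks
# (validate_dataset_size); both are assumed to raise HTTPException on violation,
# which is how the quota handlers below treat them.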


def get_enhanced_training_service():
    """Dependency injection for EnhancedTrainingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    return EnhancedTrainingService(database_manager)
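
# Testing note: because the service is wired through Depends(), it can be swapped
# without monkeypatching. A minimal sketch (assumes an "app" FastAPI instance and
# a FakeTrainingService test double, neither of which is defined in this module):
#
#     from app.main import app
#     app.dependency_overrides[get_enhanced_training_service] = lambda: FakeTrainingService()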


@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Optional[Request] = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
    rate_limiter = Depends(get_rate_limiter),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all tenant products (Admin+ only, quota enforced).

    **RBAC:** Admin or Owner role required
    **Quotas:**
    - Starter: 1 training job/day, max 1,000 rows
    - Professional: 5 training jobs/day, max 10,000 rows
    - Enterprise: unlimited jobs, unlimited rows

    Immediate-response pattern:
    1. Validate subscription tier and quotas
    2. Validate the request
    3. Create the job record using the repository pattern
    4. Return 200 with job details
    5. Execute training in the background with repository tracking

    Features:
    - Repository pattern for data access
    - Quota enforcement by subscription tier
    - Audit logging for all operations
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    # Get subscription tier and enforce quotas
    tier = current_user.get('subscription_tier', 'starter')

    # Estimate dataset size (ideally supplied by the request or computed upstream);
    # fall back to a conservative default when the request does not carry it
    estimated_dataset_size = getattr(request, 'estimated_rows', 500)

    # Initialize variables for later use
    quota_result = None
    quota_limit = None

    try:
        # Validate dataset size limits
        await rate_limiter.validate_dataset_size(
            tenant_id, estimated_dataset_size, tier
        )

        # Check daily training job quota
        quota_limit = get_training_job_quota(tier)
        quota_result = await rate_limiter.check_and_increment_quota(
            tenant_id,
            "training_jobs",
            quota_limit,
            period=86400  # 24 hours
        )
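
        # Assumption: check_and_increment_quota returns a mapping that includes a
        # 'current' usage count (the code below reads quota_result.get('current', 0));
        # the exact contract lives in shared.security.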

        logger.info("Training job quota check passed",
                    tenant_id=tenant_id,
                    tier=tier,
                    current_usage=quota_result.get('current', 0) if quota_result else 0,
                    limit=quota_limit)

    except HTTPException:
        # Quota or validation error - re-raise
        raise
    except Exception as quota_error:
        # Fail open: log the error and continue with job creation
        logger.error("Quota validation failed", error=str(quota_error))

    try:
        # Check for existing running jobs before starting a new one; this prevents
        # duplicate tenant-level training jobs
        async with enhanced_training_service.database_manager.get_session() as check_session:
            from app.repositories.training_log_repository import TrainingLogRepository
            log_repo = TrainingLogRepository(check_session)

            # Check for active jobs (running or pending)
            active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
            pending_jobs = await log_repo.get_logs_by_tenant(
                tenant_id=tenant_id,
                status="pending",
                limit=10
            )

            all_active = active_jobs + pending_jobs

            if all_active:
                # Training job already in progress: return the existing job info
                existing_job = all_active[0]
                logger.info("Training job already in progress, returning existing job",
                            existing_job_id=existing_job.job_id,
                            tenant_id=tenant_id,
                            status=existing_job.status)

                return TrainingJobResponse(
                    job_id=existing_job.job_id,
                    tenant_id=tenant_id,
                    status=existing_job.status,
                    message=f"Training job already in progress (started {existing_job.created_at.isoformat() if existing_job.created_at else 'recently'})",
                    created_at=existing_job.created_at or datetime.now(timezone.utc),
                    estimated_duration_minutes=existing_job.config.get("estimated_duration_minutes", 15) if existing_job.config else 15,
                    training_results={
                        "total_products": 0,
                        "successful_trainings": 0,
                        "failed_trainings": 0,
                        "products": [],
                        "overall_training_time_seconds": 0.0
                    },
                    data_summary=None,
                    completed_at=None,
                    error_details=None,
                    processing_metadata={
                        "background_task": True,
                        "async_execution": True,
                        "existing_job": True,
                        "deduplication": True
                    }
                )

        # No existing job: create a new one
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        # Calculate an intelligent time estimate. The exact product count is not
        # known yet, so use the tenant's historical average when available.
        try:
            historical_avg = await get_historical_average_estimate(db, tenant_id)

            # If no historical data, estimate based on a typical product count
            estimated_products = 15  # Conservative estimate
            estimated_duration_minutes = calculate_initial_estimate(
                total_products=estimated_products,
                avg_training_time_per_product=historical_avg if historical_avg else 60.0
            )
        except Exception as est_error:
            logger.warning("Could not calculate intelligent estimate, using default",
                           error=str(est_error))
            estimated_duration_minutes = 15  # Default fallback

        # Calculate the estimated completion time
        # (currently informational only; it is not included in the response below)
        estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

        # Note: the training.started event is published by the trainer with an
        # accurate product count; we do not publish here to avoid duplicate events

        # Queue the background task
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),  # NOTE: hardcoded default (Madrid)
            requested_start=request.start_date,
            requested_end=request.end_date,
            estimated_duration_minutes=estimated_duration_minutes
        )
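
        # NOTE: FastAPI/Starlette background tasks run in-process after the
        # response is sent; they are not persisted, so a job queued here is lost
        # if the process restarts before it runs. A broker-backed queue would be
        # needed for durability -- out of scope for this handler.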

        # Return an immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": estimated_duration_minutes,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        # Log audit event for training job creation
        try:
            from app.core.database import database_manager
            # Use a dedicated session for the audit write so it does not shadow
            # the request-scoped `db` session injected above
            async with database_manager.get_session() as audit_session:
                await audit_logger.log_event(
                    db_session=audit_session,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"],
                    action=AuditAction.CREATE.value,
                    resource_type="training_job",
                    resource_id=job_id,
                    severity=AuditSeverity.MEDIUM.value,
                    description=f"Started training job (tier: {tier})",
                    audit_metadata={
                        "job_id": job_id,
                        "tier": tier,
                        "estimated_dataset_size": estimated_dataset_size,
                        "quota_usage": quota_result.get('current', 0) if quota_result else 0,
                        "quota_limit": quota_limit if quota_limit else "unlimited"
                    },
                    endpoint="/jobs",
                    method="POST"
                )
        except Exception as audit_error:
            logger.warning("Failed to log audit event", error=str(audit_error))

        return TrainingJobResponse(**response_data)

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        )


async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None,
    estimated_duration_minutes: int = 15
):
    """
    Background task that executes the training job using the repository pattern.

    Features:
    - Repository pattern for all data operations
    - Enhanced error handling with structured logging
    - Transactional operations for data consistency
    - Comprehensive metrics tracking
    - Database connection pooling
    - Enhanced progress reporting
    """

    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    # Build the training service with its own database manager (this runs outside
    # the request scope, so request-scoped dependencies are unavailable here)
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        # Create the initial training log entry first
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Starting enhanced training job",
            tenant_id=tenant_id
        )

        # The training.started event is published by the training service itself
        # when it starts execution

        training_config = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "bakery_location": {
                "latitude": bakery_location[0],
                "longitude": bakery_location[1]
            },
            "requested_start": requested_start.isoformat() if requested_start else None,
            "requested_end": requested_end.isoformat() if requested_end else None,
            "estimated_duration_minutes": estimated_duration_minutes,
            "background_execution": True,
            "enhanced_features": True,
            "repository_pattern": True,
            "api_version": "enhanced_v1"
        }
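
        # NOTE: training_config is assembled above but never passed to
        # start_training_job() below; it is kept as-is on the assumption that the
        # service builds and persists its own config (the dedup logic in the API
        # reads job.config, so a config is evidently stored somewhere).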

        # Update job status using the repository pattern
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Initializing enhanced training pipeline",
            tenant_id=tenant_id
        )

        # Execute the training pipeline
        result = await enhanced_training_service.start_training_job(
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=bakery_location,
            requested_start=requested_start,
            requested_end=requested_end
        )

        # Note: the final status is already updated by start_training_job() via
        # complete_training_log(); a redundant update here caused duplicate log
        # entries. The completion event is also published by the training service.

        logger.info("Enhanced background training job completed successfully",
                    job_id=job_id,
                    models_created=result.get('products_trained', 0),
                    features=["repository-pattern", "enhanced-tracking"])

    except Exception as training_error:
        logger.error("Enhanced training pipeline failed",
                     job_id=job_id,
                     error=str(training_error))

        try:
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Enhanced training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

        # Publish the failure event from here (unlike started/completed events,
        # which the training service publishes itself)
        await publish_training_failed(job_id, tenant_id, str(training_error))

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)
@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Optional[Request] = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product (Admin+ only).

    **RBAC:** Admin or Owner role required

    Features:
    - Repository pattern for data access
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations
    - Background execution to prevent blocking
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Check whether this product is already being trained; this prevents
        # duplicate training from rapid-click scenarios
        async with enhanced_training_service.database_manager.get_session() as check_session:
            from app.repositories.training_log_repository import TrainingLogRepository
            log_repo = TrainingLogRepository(check_session)

            # Check for active jobs for this specific product
            active_jobs = await log_repo.get_active_jobs(tenant_id=tenant_id)
            pending_jobs = await log_repo.get_logs_by_tenant(
                tenant_id=tenant_id,
                status="pending",
                limit=20
            )

            all_active = active_jobs + pending_jobs

            # Filter for jobs that include this specific product
            product_jobs = [
                job for job in all_active
                if job.config and (
                    # Single-product job for this product
                    job.config.get("product_id") == inventory_product_id or
                    # Tenant-wide job that would include this product
                    job.config.get("job_type") == "tenant_training"
                )
            ]
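
            # NOTE: this check-then-act is not atomic; two near-simultaneous
            # requests can both observe "no active job" and proceed. A unique
            # constraint or advisory lock at the repository layer would close
            # the race -- not implemented here.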

            if product_jobs:
                existing_job = product_jobs[0]
                logger.warning("Product training already in progress, rejecting duplicate request",
                               existing_job_id=existing_job.job_id,
                               tenant_id=tenant_id,
                               inventory_product_id=inventory_product_id,
                               status=existing_job.status)

                raise HTTPException(
                    status_code=status.HTTP_409_CONFLICT,
                    detail={
                        "error": "Product training already in progress",
                        "message": f"Product {inventory_product_id} is currently being trained in job {existing_job.job_id}",
                        "existing_job_id": existing_job.job_id,
                        "status": existing_job.status,
                        "started_at": existing_job.created_at.isoformat() if existing_job.created_at else None
                    }
                )

        # No existing job: proceed with training
        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate the job ID (same scheme as tenant-level jobs)
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Create the initial training log entry
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Initializing single product training",
            tenant_id=tenant_id
        )

        # Queue the background task for single product training
        background_tasks.add_task(
            execute_single_product_training_background,
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038),
            database_manager=enhanced_training_service.database_manager
        )

        # Return an immediate response with job info
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced single product training started successfully",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": 15,  # Default estimate for a single product
            "training_results": {
                "total_products": 1,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced single product training queued successfully",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_queued_total")

        return TrainingJobResponse(**response_data)

    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 409 above) unchanged instead
        # of letting the generic handler below convert them into a 500
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        )


async def execute_single_product_training_background(
    tenant_id: str,
    inventory_product_id: str,
    job_id: str,
    bakery_location: tuple,
    database_manager
):
    """
    Background task that executes single product training using the repository pattern.
    Uses a separate service instance to avoid session conflicts.
    """
    logger.info("Enhanced background single product training started",
                job_id=job_id,
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id)

    # Create a new service instance with a fresh database session to avoid conflicts
    fresh_training_service = EnhancedTrainingService(database_manager)

    try:
        # Update job status to running
        await fresh_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Starting single product training",
            tenant_id=tenant_id
        )

        # Execute single product training with the repository pattern
        result = await fresh_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=bakery_location
        )

        logger.info("Enhanced background single product training completed successfully",
                    job_id=job_id,
                    inventory_product_id=inventory_product_id)

    except Exception as training_error:
        logger.error("Enhanced single product training failed",
                     job_id=job_id,
                     inventory_product_id=inventory_product_id,
                     error=str(training_error))

        try:
            await fresh_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Single product training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

    finally:
        logger.info("Enhanced background single product training cleanup completed",
                    job_id=job_id,
                    inventory_product_id=inventory_product_id)


@router.get("/health")
async def health_check():
    """Health check endpoint for training operations"""
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "3.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================

@router.delete(
    route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def delete_tenant_data(
    tenant_id: str = Path(..., description="Tenant ID to delete data for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Delete all training data for a tenant (internal service only).

    This endpoint is called by the orchestrator during tenant deletion.
    It permanently deletes all training-related data, including:
    - Trained models (all versions)
    - Model artifacts (files and metadata)
    - Training logs and job history
    - Model performance metrics
    - Training job queue entries
    - Audit logs

    **WARNING**: This operation is irreversible!
    **NOTE**: Physical model files (.pkl) must be cleaned up separately.

    Returns:
        Deletion summary with counts of deleted records
    """
    from app.services.tenant_deletion_service import TrainingTenantDeletionService

    try:
        logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id)

        db_manager = create_database_manager(settings.DATABASE_URL, "training")

        async with db_manager.get_session() as session:
            deletion_service = TrainingTenantDeletionService(session)
            result = await deletion_service.safe_delete_tenant_data(tenant_id)

            if not result.success:
                raise HTTPException(
                    status_code=500,
                    detail=f"Tenant data deletion failed: {', '.join(result.errors)}"
                )

            return {
                "message": "Tenant data deletion completed successfully",
                "note": "Physical model files should be cleaned up separately from storage",
                "summary": result.to_dict()
            }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("training.tenant_deletion.api_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete tenant data: {str(e)}"
        )


@router.get(
    route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def preview_tenant_data_deletion(
    tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Preview what data would be deleted for a tenant (dry run).

    This endpoint shows counts of all data that would be deleted,
    without actually deleting anything. Useful for:
    - Confirming deletion scope before execution
    - Auditing and compliance
    - Troubleshooting

    Returns:
        Dictionary with entity names and their counts
    """
    from app.services.tenant_deletion_service import TrainingTenantDeletionService

    try:
        logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id)

        db_manager = create_database_manager(settings.DATABASE_URL, "training")

        async with db_manager.get_session() as session:
            deletion_service = TrainingTenantDeletionService(session)
            preview = await deletion_service.get_tenant_data_preview(tenant_id)

            total_records = sum(preview.values())

            return {
                "tenant_id": tenant_id,
                "service": "training",
                "preview": preview,
                "total_records": total_records,
                "note": "Physical model files (.pkl, metadata) are not counted here",
                "warning": "These records will be permanently deleted and cannot be recovered"
            }

    except Exception as e:
        logger.error("training.tenant_deletion.preview_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to preview tenant data deletion: {str(e)}"
        )