Critical fixes for training session logging:
1. Training log race condition fix:
- Add explicit session commits after creating training logs
- Handle duplicate-key errors gracefully when multiple sessions
  try to create the same log simultaneously
- Implement retry logic that queries for the existing log after
  a duplicate-key violation
- Prevent "Training log not found" errors during training
2. Audit event async generator error fix:
- Replace the incorrect next(get_db()) usage with the proper
  async context manager (database_manager.get_session())
- Fix the "'async_generator' object is not an iterator" error
- Ensure audit logging works correctly
These changes address race conditions between concurrent database
sessions and ensure training logs stay synchronized across the
training pipeline.
617 lines
23 KiB
Python
"""
|
|
Training Operations API - BUSINESS logic
|
|
Handles training job execution and metrics
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Request, Path
|
|
from typing import Optional, Dict, Any
|
|
import structlog
|
|
from datetime import datetime, timezone
|
|
import uuid
|
|
import shared.redis_utils
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from shared.routing import RouteBuilder
|
|
from shared.monitoring.decorators import track_execution_time
|
|
from shared.monitoring.metrics import get_metrics_collector
|
|
from shared.database.base import create_database_manager
|
|
from shared.auth.decorators import get_current_user_dep
|
|
from shared.auth.access_control import require_user_role, admin_role_required, service_only_access
|
|
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
|
|
from shared.subscription.plans import (
|
|
get_training_job_quota,
|
|
get_dataset_size_limit
|
|
)
|
|
|
|
from app.services.training_service import EnhancedTrainingService
|
|
from app.schemas.training import (
|
|
TrainingJobRequest,
|
|
SingleProductTrainingRequest,
|
|
TrainingJobResponse
|
|
)
|
|
from app.utils.time_estimation import (
|
|
calculate_initial_estimate,
|
|
calculate_estimated_completion_time,
|
|
get_historical_average_estimate
|
|
)
|
|
from app.services.training_events import (
|
|
publish_training_started,
|
|
publish_training_completed,
|
|
publish_training_failed
|
|
)
|
|
from app.core.config import settings
|
|
from app.core.database import get_db
|
|
from app.models import AuditLog
|
|
|
|
# Structured logger shared by every handler in this module.
logger = structlog.get_logger()

# Router for the training-operations endpoints; route paths are produced by
# the service-scoped RouteBuilder.
router = APIRouter(tags=["training-operations"])
route_builder = RouteBuilder('training')

# Audit logger that records events through the AuditLog model.
audit_logger = create_audit_logger("training-service", AuditLog)

# Redis client used for rate limiting. Starts empty and is filled in lazily by
# get_training_redis_client() on first use.
_redis_client: Optional[Any] = None
|
|
|
|
async def get_training_redis_client():
    """Return the module-wide Redis client used for rate limiting.

    The client is created lazily on the first call and cached in the
    module-level ``_redis_client``. The first attempt uses
    ``shared.redis_utils.initialize_redis(settings.REDIS_URL)``. If that raises,
    the client is fetched with ``shared.redis_utils.get_redis_client()``; per the
    original comment, this covers the case where Redis was already initialized
    elsewhere.

    Returns:
        The cached Redis client object.
    """
    global _redis_client

    if _redis_client is None:
        try:
            _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        except Exception as init_error:
            # Narrowed from a bare ``except:``. That form also caught
            # BaseExceptions such as asyncio.CancelledError, hiding task
            # cancellation behind the fallback path.
            logger.warning("Redis initialization failed, falling back to existing client",
                           error=str(init_error))
            _redis_client = await shared.redis_utils.get_redis_client()

    return _redis_client
|
|
|
|
async def get_rate_limiter():
    """FastAPI dependency: build a rate limiter on top of the shared Redis client."""
    return create_rate_limiter(await get_training_redis_client())
|
|
|
|
def get_enhanced_training_service():
    """FastAPI dependency returning a fresh ``EnhancedTrainingService``.

    Each call creates a new database manager for ``settings.DATABASE_URL``
    under the name "training-service" and hands it to the service.
    NOTE(review): if the manager owns a connection pool, a per-request manager
    may be costly. Consider sharing one, after confirming the service is safe
    to reuse.
    """
    return EnhancedTrainingService(
        create_database_manager(settings.DATABASE_URL, "training-service")
    )
|
|
|
|
|
|
async def _enforce_training_quota(
    rate_limiter,
    tenant_id: str,
    tier: str,
    estimated_dataset_size,
):
    """Validate the dataset size and consume one slot of the daily training-job quota.

    Args:
        rate_limiter: Limiter provided by the ``get_rate_limiter`` dependency.
        tenant_id: Tenant requesting the job.
        tier: Subscription tier used to look up the limits.
        estimated_dataset_size: Estimated number of rows to train on.

    Returns:
        ``(quota_result, quota_limit)``. Either value may still be ``None`` when
        a non-HTTP error interrupted the check before it was obtained.

    Raises:
        HTTPException: Re-raised unchanged from the rate limiter (quota or
            dataset-size violation).
    """
    quota_result = None
    quota_limit = None

    try:
        await rate_limiter.validate_dataset_size(
            tenant_id, estimated_dataset_size, tier
        )

        # Check daily training job quota
        quota_limit = get_training_job_quota(tier)
        quota_result = await rate_limiter.check_and_increment_quota(
            tenant_id,
            "training_jobs",
            quota_limit,
            period=86400  # 24 hours
        )

        logger.info("Training job quota check passed",
                    tenant_id=tenant_id,
                    tier=tier,
                    current_usage=quota_result.get('current', 0) if quota_result else 0,
                    limit=quota_limit)
    except HTTPException:
        # Quota or validation error - surface it to the client.
        raise
    except Exception as quota_error:
        # Any other failure is logged and job creation continues (fail-open).
        logger.error("Quota validation failed",
                     tenant_id=tenant_id,
                     error=str(quota_error))

    return quota_result, quota_limit


async def _estimate_training_duration_minutes(db: AsyncSession, tenant_id: str):
    """Return an estimated training-job duration in minutes.

    Combines the tenant's historical average per-product training time (60.0
    when there is no history) with an assumed catalogue of 15 products. The
    real product count is only known once the trainer runs. Any error falls
    back to 15 minutes.
    """
    try:
        historical_avg = await get_historical_average_estimate(db, tenant_id)

        # Product count is unknown at this point; use a conservative guess.
        estimated_products = 15
        return calculate_initial_estimate(
            total_products=estimated_products,
            avg_training_time_per_product=historical_avg if historical_avg else 60.0
        )
    except Exception as est_error:
        logger.warning("Could not calculate intelligent estimate, using default",
                       error=str(est_error))
        return 15  # Default fallback


async def _log_training_job_audit(
    tenant_id: str,
    current_user: Dict[str, Any],
    job_id: str,
    tier: str,
    estimated_dataset_size,
    quota_result,
    quota_limit,
) -> None:
    """Write a best-effort audit event for a newly queued training job.

    Opens its own session with ``app.core.database.database_manager.get_session()``
    (an async context manager) instead of reusing the endpoint's injected ``db``
    session. Every failure is logged as a warning and swallowed, including a
    missing ``"user_id"`` key in ``current_user``, so auditing never fails the
    request.
    """
    try:
        # Local import; presumably avoids an import cycle - confirm.
        from app.core.database import database_manager

        async with database_manager.get_session() as audit_session:
            await audit_logger.log_event(
                db_session=audit_session,
                tenant_id=tenant_id,
                user_id=current_user["user_id"],
                action=AuditAction.CREATE.value,
                resource_type="training_job",
                resource_id=job_id,
                severity=AuditSeverity.MEDIUM.value,
                description=f"Started training job (tier: {tier})",
                metadata={
                    "job_id": job_id,
                    "tier": tier,
                    "estimated_dataset_size": estimated_dataset_size,
                    "quota_usage": quota_result.get('current', 0) if quota_result else 0,
                    "quota_limit": quota_limit if quota_limit else "unlimited"
                },
                endpoint="/jobs",
                method="POST"
            )
    except Exception as audit_error:
        logger.warning("Failed to log audit event", error=str(audit_error))


@router.post(
    route_builder.build_base_route("jobs"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_training_job_duration_seconds", "training-service")
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service),
    rate_limiter = Depends(get_rate_limiter),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all tenant products (Admin+ only, quota enforced).

    **RBAC:** Admin or Owner role required

    **Quotas:**
    - Starter: 1 training job/day, max 1,000 rows
    - Professional: 5 training jobs/day, max 10,000 rows
    - Enterprise: Unlimited jobs, unlimited rows

    Enhanced immediate response pattern:
    1. Validate subscription tier and quotas
    2. Validate request with enhanced validation
    3. Create job record using repository pattern
    4. Return 200 with enhanced job details
    5. Execute enhanced training in background with repository tracking

    Enhanced features:
    - Repository pattern for data access
    - Quota enforcement by subscription tier
    - Audit logging for all operations
    - Enhanced error handling and logging
    - Metrics tracking and monitoring
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    # Get subscription tier and enforce quotas
    tier = current_user.get('subscription_tier', 'starter')

    # The dataset size is not known yet. Use the request's estimate when it is
    # present and not None, otherwise assume 500 rows. Passing None through
    # could make validate_dataset_size fail, and because the quota increment
    # runs after it, quota accounting would then be skipped entirely.
    estimated_dataset_size = getattr(request, 'estimated_rows', None)
    if estimated_dataset_size is None:
        estimated_dataset_size = 500

    quota_result, quota_limit = await _enforce_training_quota(
        rate_limiter, tenant_id, tier, estimated_dataset_size
    )

    try:
        # Generate enhanced job ID
        job_id = f"enhanced_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Creating enhanced training job using repository pattern",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Record job creation metrics
        if metrics:
            metrics.increment_counter("enhanced_training_jobs_created_total")

        estimated_duration_minutes = await _estimate_training_duration_minutes(db, tenant_id)

        # training.started is published by the trainer with the accurate product
        # count; publishing it here would duplicate the event.
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),  # hard-coded default location
            requested_start=request.start_date,
            requested_end=request.end_date,
            estimated_duration_minutes=estimated_duration_minutes
        )

        # Return enhanced immediate success response
        response_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "message": "Enhanced training job started successfully using repository pattern",
            "created_at": datetime.now(timezone.utc),
            "estimated_duration_minutes": estimated_duration_minutes,
            "training_results": {
                "total_products": 0,
                "successful_trainings": 0,
                "failed_trainings": 0,
                "products": [],
                "overall_training_time_seconds": 0.0
            },
            "data_summary": None,
            "completed_at": None,
            "error_details": None,
            "processing_metadata": {
                "background_task": True,
                "async_execution": True,
                "enhanced_features": True,
                "repository_pattern": True,
                "dependency_injection": True
            }
        }

        logger.info("Enhanced training job queued successfully",
                    job_id=job_id,
                    features=["repository-pattern", "dependency-injection", "enhanced-tracking"])

        await _log_training_job_audit(
            tenant_id,
            current_user,
            job_id,
            tier,
            estimated_dataset_size,
            quota_result,
            quota_limit,
        )

        return TrainingJobResponse(**response_data)

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_training_validation_errors_total")
        logger.error("Enhanced training job validation error",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_training_job_errors_total")
        logger.error("Failed to queue enhanced training job",
                     error=str(e),
                     tenant_id=tenant_id,
                     exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start enhanced training job"
        ) from e
|
|
|
|
|
|
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None,
    estimated_duration_minutes: int = 15
):
    """
    Execute a training job in the background (scheduled by start_training_job).

    Steps:
    1. Record the job as "pending" and then "running" through
       ``EnhancedTrainingService._update_job_status_repository``.
    2. Run ``EnhancedTrainingService.start_training_job``. The final status is
       written by the service itself (via complete_training_log()).
    3. On failure, mark the job "failed" and publish a training-failed event.

    Args:
        tenant_id: Tenant that owns the job.
        job_id: Job identifier generated by the API endpoint.
        bakery_location: ``(latitude, longitude)`` pair passed to the trainer.
        requested_start: Optional start of the requested training range.
        requested_end: Optional end of the requested training range.
        estimated_duration_minutes: Accepted for interface compatibility; not
            used by this function.

    Training errors are handled here (logged, recorded on the job, published)
    rather than propagated.
    """
    logger.info("Enhanced background training job started",
                job_id=job_id,
                tenant_id=tenant_id,
                features=["repository-pattern", "enhanced-tracking"])

    # Get enhanced training service with dependency injection
    database_manager = create_database_manager(settings.DATABASE_URL, "training-service")
    enhanced_training_service = EnhancedTrainingService(database_manager)

    try:
        # Create the initial training log entry first.
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="pending",
            progress=0,
            current_step="Starting enhanced training job",
            tenant_id=tenant_id
        )

        # training.started is published by the training service itself when
        # it starts execution.
        await enhanced_training_service._update_job_status_repository(
            job_id=job_id,
            status="running",
            progress=0,
            current_step="Initializing enhanced training pipeline",
            tenant_id=tenant_id
        )

        result = await enhanced_training_service.start_training_job(
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=bakery_location,
            requested_start=requested_start,
            requested_end=requested_end
        )

    except Exception as training_error:
        logger.error("Enhanced training pipeline failed",
                     job_id=job_id,
                     error=str(training_error),
                     exc_info=True)

        try:
            await enhanced_training_service._update_job_status_repository(
                job_id=job_id,
                status="failed",
                progress=0,
                current_step="Enhanced training failed",
                error_message=str(training_error),
                tenant_id=tenant_id
            )
        except Exception as status_error:
            logger.error("Failed to update job status after training error",
                         job_id=job_id,
                         status_error=str(status_error))

        # Publish the failure so subscribers learn the job did not complete.
        # Guarded so a broker problem cannot escape the background task.
        # NOTE(review): confirm the training service does not also publish this
        # event; otherwise subscribers receive it twice.
        try:
            await publish_training_failed(job_id, tenant_id, str(training_error))
        except Exception as publish_error:
            logger.error("Failed to publish training failed event",
                         job_id=job_id,
                         error=str(publish_error))

    else:
        # Success path lives outside the try so that a problem reading the
        # result can never flip an already-completed job to "failed".
        # The final status was already written by start_training_job() via
        # complete_training_log(); the completion event is published by the
        # training service.
        models_created = result.get('products_trained', 0) if isinstance(result, dict) else 0
        logger.info("Enhanced background training job completed successfully",
                    job_id=job_id,
                    models_created=models_created,
                    features=["repository-pattern", "enhanced-tracking"])

    finally:
        logger.info("Enhanced background training job cleanup completed",
                    job_id=job_id)
|
|
|
|
|
|
@router.post(
    route_builder.build_resource_detail_route("products", "inventory_product_id"), response_model=TrainingJobResponse)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_single_product_training_duration_seconds", "training-service")
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    request_obj: Request = None,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    enhanced_training_service: EnhancedTrainingService = Depends(get_enhanced_training_service)
):
    """
    Start enhanced training for a single product (Admin+ only).

    **RBAC:** Admin or Owner role required

    Enhanced features:
    - Repository pattern for data access
    - Enhanced error handling and validation
    - Metrics tracking
    - Transactional operations
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Starting enhanced single product training",
                    inventory_product_id=inventory_product_id,
                    tenant_id=tenant_id)

        # Record metrics
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_total")

        # Generate enhanced job ID
        job_id = f"enhanced_single_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

        # Delegate to the training service; fall back to the hard-coded default
        # location when the request provides none.
        result = await enhanced_training_service.start_single_product_training(
            tenant_id=tenant_id,
            inventory_product_id=inventory_product_id,
            job_id=job_id,
            bakery_location=request.bakery_location or (40.4168, -3.7038)
        )

        if metrics:
            metrics.increment_counter("enhanced_single_product_training_success_total")

        logger.info("Enhanced single product training completed",
                    inventory_product_id=inventory_product_id,
                    job_id=job_id)

        return TrainingJobResponse(**result)

    except HTTPException:
        # Keep deliberate HTTP errors raised downstream instead of masking them
        # as a 500 (matches start_training_job).
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_validation_errors_total")
        logger.error("Enhanced single product training validation error",
                     error=str(e),
                     inventory_product_id=inventory_product_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        ) from e
    except Exception as e:
        if metrics:
            metrics.increment_counter("enhanced_single_product_training_errors_total")
        logger.error("Enhanced single product training failed",
                     error=str(e),
                     inventory_product_id=inventory_product_id,
                     exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Enhanced single product training failed"
        ) from e
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """Health check endpoint for the training operations"""
    return {
        "status": "healthy",
        "service": "training-operations",
        "version": "3.0.0",
        "features": [
            "repository-pattern",
            "dependency-injection",
            "enhanced-error-handling",
            "metrics-tracking",
            "transactional-operations"
        ],
        # Timezone-aware UTC, consistent with the rest of this module; the
        # previous naive datetime.now() reported ambiguous server-local time.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
|
|
|
|
|
|
# ============================================================================
|
|
# Tenant Data Deletion Operations (Internal Service Only)
|
|
# ============================================================================
|
|
|
|
@router.delete(
    route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def delete_tenant_data(
    tenant_id: str = Path(..., description="Tenant ID to delete data for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Delete all training data for a tenant (Internal service only)

    This endpoint is called by the orchestrator during tenant deletion.
    It permanently deletes all training-related data including:
    - Trained models (all versions)
    - Model artifacts (files and metadata)
    - Training logs and job history
    - Model performance metrics
    - Training job queue entries
    - Audit logs

    **WARNING**: This operation is irreversible!
    **NOTE**: Physical model files (.pkl) should be cleaned up separately

    Returns:
        Deletion summary with counts of deleted records
    """
    from app.services.tenant_deletion_service import TrainingTenantDeletionService

    try:
        logger.info("training.tenant_deletion.api_called", tenant_id=tenant_id)

        # NOTE(review): this call site uses the name "training", whereas the
        # other call sites in this module use "training-service" - confirm
        # the difference is intended.
        db_manager = create_database_manager(settings.DATABASE_URL, "training")

        async with db_manager.get_session() as session:
            deletion_service = TrainingTenantDeletionService(session)
            result = await deletion_service.safe_delete_tenant_data(tenant_id)

            if not result.success:
                # str() each entry so a non-string error object cannot turn a
                # clean failure report into a TypeError.
                raise HTTPException(
                    status_code=500,
                    detail=f"Tenant data deletion failed: {', '.join(map(str, result.errors))}"
                )

            return {
                "message": "Tenant data deletion completed successfully",
                "note": "Physical model files should be cleaned up separately from storage",
                "summary": result.to_dict()
            }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("training.tenant_deletion.api_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete tenant data: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def preview_tenant_data_deletion(
    tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Preview what data would be deleted for a tenant (dry-run)

    This endpoint shows counts of all data that would be deleted
    without actually deleting anything. Useful for:
    - Confirming deletion scope before execution
    - Auditing and compliance
    - Troubleshooting

    Returns:
        Dictionary with entity names and their counts
    """
    from app.services.tenant_deletion_service import TrainingTenantDeletionService

    try:
        logger.info("training.tenant_deletion.preview_called", tenant_id=tenant_id)

        # NOTE(review): name "training" differs from "training-service" used
        # elsewhere in this module - confirm intended.
        db_manager = create_database_manager(settings.DATABASE_URL, "training")

        async with db_manager.get_session() as session:
            deletion_service = TrainingTenantDeletionService(session)
            preview = await deletion_service.get_tenant_data_preview(tenant_id)

            total_records = sum(preview.values())

            return {
                "tenant_id": tenant_id,
                "service": "training",
                "preview": preview,
                "total_records": total_records,
                "note": "Physical model files (.pkl, metadata) are not counted here",
                "warning": "These records will be permanently deleted and cannot be recovered"
            }

    except HTTPException:
        # Pass through deliberate HTTP errors unchanged (matches delete_tenant_data).
        raise
    except Exception as e:
        logger.error("training.tenant_deletion.preview_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to preview tenant data deletion: {str(e)}"
        ) from e
|