Files
bakery-ia/services/training/app/api/training.py
2025-08-04 18:21:42 +02:00

521 lines
18 KiB
Python

# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
"""
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
import structlog
from datetime import datetime, timezone
import uuid
from app.core.database import get_db, get_background_db_session
from app.services.training_service import TrainingService, TrainingStatusManager
from sqlalchemy import select, delete, func
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest
)
from app.schemas.training import (
TrainingJobResponse
)
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed,
publish_job_completed,
publish_job_failed,
publish_job_started
)
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
# Module-level structlog logger used by every endpoint and background task in this file.
logger = structlog.get_logger()
# Router on which the training endpoints below register themselves via @router.<method>.
router = APIRouter()
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue a training job covering every product of a tenant.

    The handler answers right away and leaves the heavy work to a background task:
    1. Reject the call with 403 when the path tenant differs from the caller's tenant
    2. Build a job ID from the tenant ID plus a short random hex suffix
    3. Schedule execute_training_job_background with the requested date range
    4. Return a TrainingJobResponse whose status is 'pending'

    The injected request-scoped session is not used here; the background task
    is scheduled without it.

    Raises:
        HTTPException 403: path tenant does not match the authenticated tenant.
        HTTPException 400: a ValueError surfaced while queuing/building the response.
        HTTPException 500: any other unexpected failure.
    """
    # Tenant isolation check; a mismatch is reported to the client as 403.
    if tenant_id != current_tenant:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to tenant resources"
        )

    try:
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
        logger.info(f"Creating training job {job_id} for tenant {tenant_id}")

        # The actual training runs after the response is sent.
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(40.4168, -3.7038),
            requested_start=request.start_date,
            requested_end=request.end_date
        )

        # Placeholder figures for the immediate answer; nothing has been trained yet.
        initial_results = {
            "total_products": 10,
            "successful_trainings": 0,
            "failed_trainings": 0,
            "products": [],
            "overall_training_time_seconds": 0.0
        }
        execution_flags = {
            "background_task": True,
            "async_execution": True
        }
        response = TrainingJobResponse(
            job_id=job_id,
            tenant_id=tenant_id,
            status="pending",  # Will change to 'running' in background
            message="Training job started successfully",
            created_at=datetime.now(timezone.utc),
            estimated_duration_minutes="15",
            training_results=initial_results,
            data_summary=None,
            completed_at=None,
            error_details=None,
            processing_metadata=execution_flags
        )
        logger.info(f"Training job {job_id} queued successfully, returning immediate response")
        return response
    except ValueError as e:
        logger.error(f"Training job validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Failed to queue training job: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start training job"
        )
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Run a training job outside the request/response cycle.

    Opens its own database session via get_background_db_session, publishes a
    "job started" event, marks the job as running, runs the training pipeline
    and finally records and publishes either completion or failure.

    Args:
        tenant_id: Tenant the job belongs to.
        job_id: Identifier of the job being executed.
        bakery_location: (latitude, longitude) pair; forwarded to the training
            service and reported in the published job configuration.
        requested_start: Optional start of the requested training date range.
        requested_end: Optional end of the requested training date range.

    This coroutine never raises: pipeline errors are turned into a "failed"
    status plus a failure event, and anything else is logged.
    """
    logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")
    async with get_background_db_session() as db_session:
        try:
            # Both helpers share the background session, not the API request's session.
            training_service = TrainingService(db_session=db_session)
            status_manager = TrainingStatusManager(db_session=db_session)
            try:
                # Report the location actually used for training instead of a
                # hard-coded copy; a malformed location fails the job below.
                latitude, longitude = bakery_location
                training_config = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "bakery_location": {
                        "latitude": latitude,
                        "longitude": longitude
                    },
                    "requested_start": requested_start,
                    "requested_end": requested_end,
                    "estimated_duration_minutes": 15,
                    "estimated_products": None,
                    "background_execution": True,
                    "api_version": "v1"
                }
                # Publish immediate event (training started)
                await publish_job_started(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    config=training_config
                )
                await status_manager.update_job_status(
                    job_id=job_id,
                    status="running",
                    progress=0,
                    current_step="Initializing training pipeline"
                )
                # Execute the actual training pipeline
                result = await training_service.start_training_job(
                    tenant_id=tenant_id,
                    job_id=job_id,
                    bakery_location=bakery_location,
                    requested_start=requested_start,
                    requested_end=requested_end
                )
                await status_manager.update_job_status(
                    job_id=job_id,
                    status="completed",
                    progress=100,
                    current_step="Training completed successfully",
                    results=result
                )
                # Publish completion event
                await publish_job_completed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    results=result
                )
                logger.info(f"✅ Background training job {job_id} completed successfully")
            except Exception as training_error:
                logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
                # Recording the failure must not prevent subscribers from being
                # told about it, so a failing status update is only logged.
                try:
                    await status_manager.update_job_status(
                        job_id=job_id,
                        status="failed",
                        progress=0,
                        current_step="Training failed",
                        error_message=str(training_error)
                    )
                except Exception as status_error:
                    logger.error(f"❌ Could not record failed status for job {job_id}: {str(status_error)}")
                # Publish failure event
                await publish_job_failed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    error=str(training_error)
                )
        except Exception as background_error:
            # Last line of defence: a background task has no caller to report to.
            logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")
        finally:
            # Only logs; the session itself is released by the enclosing `async with`.
            logger.info(f"🧹 Background training job {job_id} cleanup completed")
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Start training for a single product.

    Delegates to TrainingService.start_single_product_training using the
    request-scoped database session and wraps the returned mapping in a
    TrainingJobResponse. When the request carries no bakery location, the
    default coordinates (40.4168, -3.7038) are used.

    Raises:
        HTTPException 403: path tenant does not match the authenticated tenant.
        HTTPException 400: the training service (or response building) raised ValueError.
        HTTPException 500: any other unexpected failure.
    """
    training_service = TrainingService(db_session=db)
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )
        logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
        # Delegate to training service
        result = await training_service.start_single_product_training(
            tenant_id=tenant_id,
            product_name=product_name,
            sales_data=request.sales_data,
            bakery_location=request.bakery_location or (40.4168, -3.7038),
            weather_data=request.weather_data,
            traffic_data=request.traffic_data,
            job_id=request.job_id
        )
        return TrainingJobResponse(**result)
    except HTTPException:
        # Without this clause the 403 above would be swallowed by the generic
        # handler below and reach the client as a 500.
        raise
    except ValueError as e:
        logger.error(f"Single product training validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Single product training failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Single product training failed"
        )
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    limit: int = Query(100, description="Number of log entries to return"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get training job logs.

    Log retrieval is not implemented yet: the response contains a fixed list
    of placeholder messages, and neither `limit` nor the database session is
    consulted.

    Raises:
        HTTPException 403: path tenant does not match the authenticated tenant.
        HTTPException 500: any other unexpected failure.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )
        # TODO: Implement log retrieval
        return {
            "job_id": job_id,
            "logs": [
                f"Training job {job_id} started",
                "Data preprocessing completed",
                "Model training completed",
                "Training job finished successfully"
            ]
        }
    except HTTPException:
        # Let the 403 above through unchanged; otherwise the generic handler
        # below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training logs"
        )
@router.get("/health")
async def health_check():
    """
    Report that the training service is up.

    Returns a static status/service/version payload plus the current time as
    an ISO-8601 string (naive, taken from datetime.now()).
    """
    checked_at = datetime.now().isoformat()
    return dict(
        status="healthy",
        service="training",
        version="1.0.0",
        timestamp=checked_at,
    )
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
    cancel_data: dict,  # {"tenant_id": str}
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel all active training jobs for a tenant (admin only).

    Every job of the tenant whose status is "queued", "running" or "pending"
    is marked "cancelled" and stamped with the cancelling user's ID; the
    changes are committed once at the end.

    NOTE(review): the tenant is taken from the request body, not from the
    `{tenant_id}` path segment, which this handler does not read — confirm
    that this is intended.

    Returns:
        A dict with the tenant ID, the number and IDs of cancelled jobs, the
        per-job error messages and the cancellation timestamp. "success" is
        False only when errors occurred for every active job.

    Raises:
        HTTPException 400: tenant_id missing from the body or not a valid UUID string.
        HTTPException 500: the query or the commit failed (the session is rolled back).
    """
    tenant_id = cancel_data.get("tenant_id")
    if not tenant_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="tenant_id is required"
        )
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises ValueError for malformed strings, but a non-string
        # JSON value (e.g. a number) fails with AttributeError/TypeError;
        # all of them are client errors, not server errors.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )
    try:
        from app.models.training import TrainingJobQueue
        # Find all active jobs for the tenant
        active_jobs_query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        active_jobs_result = await db.execute(active_jobs_query)
        active_jobs = active_jobs_result.scalars().all()
        jobs_cancelled = 0
        cancelled_job_ids = []
        errors = []
        for job in active_jobs:
            # A failure on one job is recorded and must not stop the others.
            try:
                job.status = "cancelled"
                # NOTE(review): naive UTC timestamp; presumably matches the
                # column type — confirm before switching to an aware datetime.
                job.updated_at = datetime.utcnow()
                job.cancelled_by = current_user.get("user_id")
                jobs_cancelled += 1
                cancelled_job_ids.append(str(job.id))
                logger.info("Cancelled training job",
                            job_id=str(job.id),
                            tenant_id=tenant_id)
            except Exception as e:
                error_msg = f"Failed to cancel job {job.id}: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)
        # Nothing to persist when no job was changed.
        if jobs_cancelled > 0:
            await db.commit()
        result = {
            "success": True,
            "tenant_id": tenant_id,
            "jobs_cancelled": jobs_cancelled,
            "cancelled_job_ids": cancelled_job_ids,
            "errors": errors,
            "cancelled_at": datetime.utcnow().isoformat()
        }
        if errors:
            # Partial success still counts as success.
            result["success"] = len(errors) < len(active_jobs)
        return result
    except Exception as e:
        await db.rollback()
        logger.error("Failed to cancel tenant training jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training jobs"
        )
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    List a tenant's training jobs that are still active (admin only).

    A job counts as active while its status is "queued", "running" or
    "pending". Timestamps are returned as ISO-8601 strings, or None when unset.

    Raises:
        HTTPException 400: tenant_id is not a valid UUID.
        HTTPException 500: the lookup failed.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    def _iso_or_none(moment):
        # Unset timestamps are reported as None instead of failing on isoformat().
        return moment.isoformat() if moment else None

    try:
        from app.models.training import TrainingJobQueue

        statement = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        rows = (await db.execute(statement)).scalars().all()

        jobs = [
            {
                "id": str(row.id),
                "tenant_id": str(row.tenant_id),
                "status": row.status,
                "created_at": _iso_or_none(row.created_at),
                "updated_at": _iso_or_none(row.updated_at),
                "started_at": _iso_or_none(row.started_at),
                # Falls back to 0 when the row has no progress attribute.
                "progress": getattr(row, 'progress', 0)
            }
            for row in rows
        ]
        return {
            "tenant_id": tenant_id,
            "active_jobs_count": len(jobs),
            "jobs": jobs
        }
    except Exception as e:
        logger.error("Failed to get tenant active jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get active jobs"
        )
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Count a tenant's trained models and model artifacts (admin only).

    Returns both counts plus their sum as "total_training_assets".

    Raises:
        HTTPException 400: tenant_id is not a valid UUID.
        HTTPException 500: one of the count queries failed.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    async def _count_rows(model):
        # COUNT(id) over the rows of `model` that belong to this tenant.
        statement = select(func.count(model.id)).where(
            model.tenant_id == tenant_uuid
        )
        outcome = await db.execute(statement)
        return outcome.scalar()

    try:
        from app.models.training import TrainedModel, ModelArtifact

        # Queries run one after the other on the same session: models first, then artifacts.
        models_count = await _count_rows(TrainedModel)
        artifacts_count = await _count_rows(ModelArtifact)
        return {
            "tenant_id": tenant_id,
            "models_count": models_count,
            "artifacts_count": artifacts_count,
            "total_training_assets": models_count + artifacts_count
        }
    except Exception as e:
        logger.error("Failed to get tenant models count",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get models count"
        )