524 lines
18 KiB
Python
524 lines
18 KiB
Python
# services/training/app/api/training.py
|
|
"""
|
|
Training API Endpoints - Entry point for training requests
|
|
Handles HTTP requests and delegates to Training Service
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
|
from fastapi import Query, Path
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from typing import List, Optional, Dict, Any
|
|
import structlog
|
|
from datetime import datetime, timezone
|
|
import uuid
|
|
|
|
from app.core.database import get_db, get_background_db_session
|
|
from app.services.training_service import TrainingService, TrainingStatusManager
|
|
from sqlalchemy import select, delete, func
|
|
from app.schemas.training import (
|
|
TrainingJobRequest,
|
|
SingleProductTrainingRequest
|
|
)
|
|
from app.schemas.training import (
|
|
TrainingJobResponse
|
|
)
|
|
|
|
from app.services.messaging import (
|
|
publish_job_progress,
|
|
publish_data_validation_started,
|
|
publish_data_validation_completed,
|
|
publish_job_step_completed,
|
|
publish_job_completed,
|
|
publish_job_failed,
|
|
publish_job_started
|
|
)
|
|
|
|
|
|
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
|
|
|
|
logger = structlog.get_logger()
|
|
router = APIRouter()
|
|
|
|
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    background_tasks: BackgroundTasks = BackgroundTasks(),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Queue a training job covering all of a tenant's products.

    The handler answers immediately and leaves the heavy work to a
    background task:

    1. Reject the call with 403 when the path tenant differs from the
       authenticated tenant.
    2. Generate a job ID and schedule ``execute_training_job_background``.
    3. Publish the "job started" event.
    4. Return a ``TrainingJobResponse`` whose status is ``pending``.

    NOTE(review): nothing is written to the database here; the injected
    ``db`` session is not used by this handler.

    Raises:
        HTTPException: 403 on tenant mismatch, 400 when a ``ValueError``
            is raised while queuing, 500 on any other failure.
    """
    try:
        # Tenant isolation check comes first so nothing is queued on mismatch.
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # Short random suffix keeps job IDs readable but distinct per request.
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Creating training job {job_id} for tenant {tenant_id}")

        # Hard-coded coordinates, shared by the background task and the
        # published config. NOTE(review): presumably a default bakery
        # location; confirm where the real one should come from.
        latitude, longitude = 40.4168, -3.7038
        start_date = request.start_date
        end_date = request.end_date

        # The task opens its own DB session, so it does not depend on `db`.
        background_tasks.add_task(
            execute_training_job_background,
            tenant_id=tenant_id,
            job_id=job_id,
            bakery_location=(latitude, longitude),
            requested_start=start_date,
            requested_end=end_date
        )

        training_config = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "bakery_location": {"latitude": latitude, "longitude": longitude},
            "requested_start": start_date.isoformat() if start_date else None,
            "requested_end": end_date.isoformat() if end_date else None,
            "estimated_duration_minutes": 15,
            "estimated_products": 10,
            "background_execution": True,
            "api_version": "v1"
        }

        # Tell subscribers right away that the job has been accepted.
        await publish_job_started(
            job_id=job_id,
            tenant_id=tenant_id,
            config=training_config
        )

        # Placeholder results: nothing has been trained yet at this point.
        empty_results = {
            "total_products": 10,
            "successful_trainings": 0,
            "failed_trainings": 0,
            "products": [],
            "overall_training_time_seconds": 0.0
        }
        response_fields = dict(
            job_id=job_id,
            tenant_id=tenant_id,
            status="pending",  # the background task switches it to 'running'
            message="Training job started successfully",
            created_at=datetime.now(timezone.utc),
            estimated_duration_minutes="15",
            training_results=empty_results,
            data_summary=None,
            completed_at=None,
            error_details=None,
            processing_metadata={
                "background_task": True,
                "async_execution": True
            }
        )

        logger.info(f"Training job {job_id} queued successfully, returning immediate response")
        return TrainingJobResponse(**response_fields)

    except HTTPException:
        # Already a well-formed HTTP error (e.g. the 403 above); pass through.
        raise
    except ValueError as exc:
        logger.error(f"Training job validation error: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc)
        )
    except Exception as exc:
        logger.error(f"Failed to queue training job: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start training job"
        )
|
|
|
|
|
|
async def execute_training_job_background(
    tenant_id: str,
    job_id: str,
    bakery_location: tuple,
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None
):
    """
    Run a training job outside the request/response cycle.

    Opens its own session through ``get_background_db_session()``, marks
    the job ``running``, executes ``TrainingService.start_training_job``,
    then records ``completed`` or ``failed`` and publishes the matching
    event. No exception escapes this coroutine: failures are logged and,
    where possible, reported through ``publish_job_failed``.

    Args:
        tenant_id: Tenant that owns the job.
        job_id: Identifier generated by the API handler.
        bakery_location: Coordinates forwarded unchanged to the training
            service (the API handler passes a ``(lat, lon)`` tuple).
        requested_start: Optional lower bound forwarded to the service.
        requested_end: Optional upper bound forwarded to the service.
    """
    logger.info(f"🚀 Background training job {job_id} started for tenant {tenant_id}")

    async with get_background_db_session() as db_session:
        try:
            # Both collaborators share the task-local session, not the
            # request-scoped one used by the API handler.
            training_service = TrainingService(db_session=db_session)
            status_manager = TrainingStatusManager(db_session=db_session)

            # Publish progress event
            await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")

            try:
                await status_manager.update_job_status(
                    job_id=job_id,
                    status="running",
                    progress=0,
                    current_step="Initializing training pipeline"
                )

                # Execute the actual training pipeline
                result = await training_service.start_training_job(
                    tenant_id=tenant_id,
                    job_id=job_id,
                    bakery_location=bakery_location,
                    requested_start=requested_start,
                    requested_end=requested_end
                )

                await status_manager.update_job_status(
                    job_id=job_id,
                    status="completed",
                    progress=100,
                    current_step="Training completed successfully",
                    results=result
                )

                # Publish completion event
                await publish_job_completed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    results=result
                )

                logger.info(f"✅ Background training job {job_id} completed successfully")

            except Exception as training_error:
                logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")

                # Recording the failure can itself fail (for instance when
                # the original error came from this same session). Guard it
                # so the failure event below is still published; otherwise
                # subscribers would never learn that the job died.
                try:
                    await status_manager.update_job_status(
                        job_id=job_id,
                        status="failed",
                        progress=0,
                        current_step="Training failed",
                        error_message=str(training_error)
                    )
                except Exception as status_error:
                    logger.error(
                        f"❌ Could not record failed status for job {job_id}: {str(status_error)}"
                    )

                # Publish failure event
                await publish_job_failed(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    error=str(training_error)
                )

        except Exception as background_error:
            # Last line of defence: a background task has no caller to
            # report to, so log and swallow.
            logger.error(f"💥 Critical error in background training job {job_id}: {str(background_error)}")

        finally:
            # Only logs; releasing the session is left to the enclosing
            # `async with` block.
            logger.info(f"🧹 Background training job {job_id} cleanup completed")
|
|
|
|
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def start_single_product_training(
    request: SingleProductTrainingRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Start training for a single product.

    Runs inline (not as a background task): the request is delegated to
    ``TrainingService.start_single_product_training`` and its result is
    returned as a ``TrainingJobResponse``.

    Raises:
        HTTPException: 403 when the path tenant differs from the
            authenticated tenant, 400 when the service raises
            ``ValueError``, 500 on any other failure.
    """
    training_service = TrainingService(db_session=db)

    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")

        # Delegate to training service; fall back to the hard-coded
        # coordinates when the request carries no bakery location.
        result = await training_service.start_single_product_training(
            tenant_id=tenant_id,
            product_name=product_name,
            sales_data=request.sales_data,
            bakery_location=request.bakery_location or (40.4168, -3.7038),
            weather_data=request.weather_data,
            traffic_data=request.traffic_data,
            job_id=request.job_id
        )

        return TrainingJobResponse(**result)

    except HTTPException:
        # Without this clause the 403 above would be caught by the generic
        # handler below and reported to the client as a 500.
        raise
    except ValueError as e:
        logger.error(f"Single product training validation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error(f"Single product training failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Single product training failed"
        )
|
|
|
|
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    limit: int = Query(100, description="Number of log entries to return"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get training job logs.

    Currently a stub: after the tenant check it returns a fixed list of
    placeholder messages. ``limit`` and ``db`` are accepted but not used
    yet.

    Raises:
        HTTPException: 403 when the path tenant differs from the
            authenticated tenant, 500 on any unexpected failure.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # TODO: Implement log retrieval (and honour `limit`)
        return {
            "job_id": job_id,
            "logs": [
                f"Training job {job_id} started",
                "Data preprocessing completed",
                "Model training completed",
                "Training job finished successfully"
            ]
        }

    except HTTPException:
        # Let the 403 above reach the client; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training logs"
        )
|
|
|
|
@router.get("/health")
async def health_check():
    """
    Health check endpoint for the training service.

    Returns a static payload plus the current time as a timezone-aware
    UTC ISO-8601 string, consistent with the other timestamps produced in
    this module via ``datetime.now(timezone.utc)``.
    """
    return {
        "status": "healthy",
        "service": "training",
        "version": "1.0.0",
        # Aware UTC time: a naive local timestamp carries no offset and is
        # ambiguous to whoever reads the response.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
|
|
|
|
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
    cancel_data: dict,  # {"tenant_id": str}
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel all active training jobs for a tenant (admin only).

    Every ``TrainingJobQueue`` row of the tenant whose status is
    ``queued``, ``running`` or ``pending`` is marked ``cancelled``; the
    changes are committed once at the end.

    NOTE(review): the tenant is read from the request body, not from the
    ``{tenant_id}`` path segment, which is not bound to any parameter
    here. Confirm whether the two are meant to be cross-checked.

    Returns:
        A dict with ``success``, ``tenant_id``, ``jobs_cancelled``,
        ``cancelled_job_ids``, ``errors`` and ``cancelled_at``.

    Raises:
        HTTPException: 400 when ``tenant_id`` is missing or is not a valid
            UUID string, 500 when the query or commit fails.
    """
    try:
        tenant_id = cancel_data.get("tenant_id")
        if not tenant_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="tenant_id is required"
            )

        tenant_uuid = uuid.UUID(tenant_id)
    except (ValueError, AttributeError, TypeError):
        # uuid.UUID raises ValueError for a malformed string, but
        # AttributeError/TypeError when the JSON body carries a non-string
        # value (number, list, object). All of these are client errors.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainingJobQueue

        # Find all active jobs for the tenant
        active_jobs_query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        active_jobs_result = await db.execute(active_jobs_query)
        active_jobs = active_jobs_result.scalars().all()

        jobs_cancelled = 0
        cancelled_job_ids = []
        errors = []

        for job in active_jobs:
            # Best effort per job: one failing row must not stop the others
            # from being cancelled.
            try:
                job.status = "cancelled"
                # NOTE(review): naive UTC is kept on purpose; switching to
                # an aware datetime is only safe if the column is
                # timezone-aware. Confirm against the model.
                job.updated_at = datetime.utcnow()
                job.cancelled_by = current_user.get("user_id")
                jobs_cancelled += 1
                cancelled_job_ids.append(str(job.id))

                logger.info("Cancelled training job",
                            job_id=str(job.id),
                            tenant_id=tenant_id)

            except Exception as e:
                error_msg = f"Failed to cancel job {job.id}: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)

        # Skip the commit when nothing was changed.
        if jobs_cancelled > 0:
            await db.commit()

        result = {
            "success": True,
            "tenant_id": tenant_id,
            "jobs_cancelled": jobs_cancelled,
            "cancelled_job_ids": cancelled_job_ids,
            "errors": errors,
            "cancelled_at": datetime.utcnow().isoformat()
        }

        if errors:
            # Still a (partial) success as long as at least one job went
            # through without an error.
            result["success"] = len(errors) < len(active_jobs)

        return result

    except Exception as e:
        await db.rollback()
        logger.error("Failed to cancel tenant training jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training jobs"
        )
|
|
|
|
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    List a tenant's active training jobs (admin only).

    A job counts as active while its status is ``queued``, ``running`` or
    ``pending``.

    Returns:
        A dict with ``tenant_id``, ``active_jobs_count`` and ``jobs`` (one
        summary dict per job, timestamps as ISO strings or ``None``).

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID, 500
            when the lookup fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainingJobQueue

        def _iso_or_none(moment):
            # Timestamps may be unset on a job row.
            return moment.isoformat() if moment else None

        query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        rows = (await db.execute(query)).scalars().all()

        jobs = [
            {
                "id": str(row.id),
                "tenant_id": str(row.tenant_id),
                "status": row.status,
                "created_at": _iso_or_none(row.created_at),
                "updated_at": _iso_or_none(row.updated_at),
                "started_at": _iso_or_none(row.started_at),
                # Rows without a `progress` attribute are reported as 0.
                "progress": getattr(row, 'progress', 0)
            }
            for row in rows
        ]

        return {
            "tenant_id": tenant_id,
            "active_jobs_count": len(jobs),
            "jobs": jobs
        }

    except Exception as exc:
        logger.error("Failed to get tenant active jobs",
                     tenant_id=tenant_id,
                     error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get active jobs"
        )
|
|
|
|
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Count a tenant's trained models and model artifacts (admin only).

    Returns:
        A dict with ``tenant_id``, ``models_count``, ``artifacts_count``
        and their sum as ``total_training_assets``.

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID, 500
            when either count query fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainedModel, ModelArtifact

        async def _count_for_tenant(model):
            # SELECT count(id) restricted to this tenant's rows.
            stmt = select(func.count(model.id)).where(
                model.tenant_id == tenant_uuid
            )
            outcome = await db.execute(stmt)
            return outcome.scalar()

        # Models first, then artifacts, one query each.
        models_count = await _count_for_tenant(TrainedModel)
        artifacts_count = await _count_for_tenant(ModelArtifact)

        return {
            "tenant_id": tenant_id,
            "models_count": models_count,
            "artifacts_count": artifacts_count,
            "total_training_assets": models_count + artifacts_count
        }

    except Exception as exc:
        logger.error("Failed to get tenant models count",
                     tenant_id=tenant_id,
                     error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get models count"
        )