Files
bakery-ia/services/training/app/api/training.py
2025-07-19 16:59:37 +02:00

299 lines
9.8 KiB
Python

# services/training/app/api/training.py
"""
Training API endpoints for the training service
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime
import uuid
from app.core.database import get_db
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatusResponse,
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Router that every endpoint in this module registers itself on.
router = APIRouter()
# Metrics collector constructed with the "training-service" identifier;
# the handlers below use it to increment started/failed counters.
metrics = MetricsCollector("training-service")
# Initialize training service
# (single module-level instance used by every handler in this module).
training_service = TrainingService()
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all products of a tenant.
    Replaces the old Celery-based training system.

    Creates the job record, schedules the training run as a background
    task, publishes the "job started" event and returns the job descriptor.

    Args:
        request: Training configuration; forwarded to the service as a dict.
        background_tasks: FastAPI task queue the training run is added to.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        TrainingJobResponse with status ``"started"``.

    Raises:
        HTTPException: propagated unchanged when raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        logger.info(f"Starting training job for tenant {tenant_id}")
        metrics.increment_counter("training_jobs_started")

        # Generate job ID (tenant id plus a short random suffix)
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_training_job(
            db=db,
            tenant_id=tenant_id,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): `db` is the session injected for this request; confirm
        # it is still usable when the background task actually runs.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job_id,
            tenant_id,
            request
        )

        # Publish training started event
        await publish_job_started(job_id, tenant_id, request.dict())

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message="Training job started successfully",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            # Fall back to 15 minutes when the request gives no estimate
            estimated_duration_minutes=request.estimated_duration or 15
        )
    except HTTPException:
        # Keep the original status code instead of masking it as a 500,
        # matching the other endpoints; still count the failure.
        metrics.increment_counter("training_jobs_failed")
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to start training job: {str(e)}")
        metrics.increment_counter("training_jobs_failed")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training job: {str(e)}"
        ) from e
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the status of a training job.
    Provides real-time progress updates.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        TrainingStatusResponse built from the stored job record.

    Raises:
        HTTPException: 404 when the service returns no record for this
            job/tenant pair; 500 on any unexpected failure.
    """
    try:
        # Get job status from database
        job_status = await training_service.get_job_status(db, job_id, tenant_id)
        if not job_status:
            raise HTTPException(status_code=404, detail="Training job not found")

        return TrainingStatusResponse(
            job_id=job_id,
            status=job_status.status,
            progress=job_status.progress,
            current_step=job_status.current_step,
            started_at=job_status.start_time,
            completed_at=job_status.end_time,
            results=job_status.results,
            error_message=job_status.error_message
        )
    except HTTPException:
        # Let the 404 above (and any other HTTP error) through untouched
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to get training status: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training status: {str(e)}"
        ) from e
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Train a model for a single product.
    Useful for quick model updates or new products.

    Creates the job record, schedules the training run as a background
    task, publishes the "product training started" event and returns the
    job descriptor.

    Args:
        product_name: Product to train, taken from the URL path.
        request: Training configuration; forwarded to the service as a dict.
        background_tasks: FastAPI task queue the training run is added to.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        TrainingJobResponse with status ``"started"`` and a fixed
        5-minute duration estimate.

    Raises:
        HTTPException: propagated unchanged when raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
        metrics.increment_counter("single_product_training_started")

        # Generate job ID (tenant id, product name and a short random suffix)
        job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_single_product_job(
            db=db,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): `db` is the session injected for this request; confirm
        # it is still usable when the background task actually runs.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            db,
            job_id,
            tenant_id,
            product_name,
            request
        )

        # Publish event
        await publish_product_training_started(job_id, tenant_id, product_name)

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message=f"Single product training started for {product_name}",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=5
        )
    except HTTPException:
        # Keep the original status code instead of masking it as a 500,
        # matching the other endpoints; still count the failure.
        metrics.increment_counter("single_product_training_failed")
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to start single product training: {str(e)}")
        metrics.increment_counter("single_product_training_failed")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training: {str(e)}"
        ) from e
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
    limit: int = 10,
    status: Optional[str] = None,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    List training jobs for a tenant.

    Args:
        limit: Maximum number of jobs requested (default 10). Forwarded to
            the service as-is; no bounds check is applied here.
        status: Optional status filter, forwarded as ``status_filter``.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        A list of TrainingStatusResponse, one per job returned by the service.

    Raises:
        HTTPException: propagated unchanged when raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        jobs = await training_service.list_training_jobs(
            db=db,
            tenant_id=tenant_id,
            limit=limit,
            status_filter=status
        )

        return [
            TrainingStatusResponse(
                job_id=job.job_id,
                status=job.status,
                progress=job.progress,
                current_step=job.current_step,
                started_at=job.start_time,
                completed_at=job.end_time,
                results=job.results,
                error_message=job.error_message
            )
            for job in jobs
        ]
    except HTTPException:
        # Keep the original status code instead of masking it as a 500,
        # matching the other endpoints.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to list training jobs: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list training jobs: {str(e)}"
        ) from e
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when the service reports the cancellation did
            not succeed; 500 on any unexpected failure.
    """
    try:
        logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")

        # Update job status to cancelled
        success = await training_service.cancel_training_job(db, job_id, tenant_id)
        if not success:
            raise HTTPException(
                status_code=404,
                detail="Training job not found or cannot be cancelled"
            )

        # Publish cancellation event (only after the service reported success)
        await publish_job_cancelled(job_id, tenant_id)

        return {"message": "Training job cancelled successfully"}
    except HTTPException:
        # Let the 404 above (and any other HTTP error) through untouched
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to cancel training job: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(e)}"
        ) from e
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get detailed logs for a training job.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        A dict with the ``job_id`` and the ``logs`` returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy value for the
            logs; 500 on any unexpected failure.
    """
    try:
        logs = await training_service.get_training_logs(db, job_id, tenant_id)
        # NOTE(review): a falsy result is treated as "job not found". If the
        # service can return an empty collection for an existing job, that
        # job is also reported as 404 — confirm against the service.
        if not logs:
            raise HTTPException(status_code=404, detail="Training job not found")

        return {"job_id": job_id, "logs": logs}
    except HTTPException:
        # Let the 404 above (and any other HTTP error) through untouched
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training logs: {str(e)}"
        ) from e
@router.post("/validate")
async def validate_training_data(
    request: TrainingJobRequest,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate training data before starting a job.
    Provides early feedback on data quality issues.

    Args:
        request: Training configuration; forwarded to the service as a dict.
        tenant_id: Tenant resolved by the auth dependency.
        db: Database session provided by ``get_db``.

    Returns:
        A dict with ``is_valid``, ``issues``, ``recommendations`` and
        ``estimated_training_time`` taken from the service result.

    Raises:
        HTTPException: propagated unchanged when raised inside the handler;
            any other failure (including a result without ``is_valid``)
            is reported as a 500.
    """
    try:
        logger.info(f"Validating training data for tenant {tenant_id}")

        # Perform data validation
        validation_result = await training_service.validate_training_data(
            db=db,
            tenant_id=tenant_id,
            config=request.dict()
        )

        return {
            # "is_valid" is required; the other keys fall back to defaults
            "is_valid": validation_result["is_valid"],
            "issues": validation_result.get("issues", []),
            "recommendations": validation_result.get("recommendations", []),
            "estimated_training_time": validation_result.get("estimated_time_minutes", 15)
        }
    except HTTPException:
        # Keep the original status code instead of masking it as a 500,
        # matching the other endpoints.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Failed to validate training data: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to validate training data: {str(e)}"
        ) from e
@router.get("/health")
async def health_check():
    """
    Health check for the training service.

    Returns:
        A dict with a fixed ``"healthy"`` status, the service name and the
        current time as a timezone-aware ISO 8601 string in UTC.
    """
    # Local import: the module header imports only `datetime` itself.
    from datetime import timezone

    return {
        "status": "healthy",
        "service": "training-service",
        # Timezone-aware UTC instead of naive local time, so the value is
        # unambiguous regardless of the host's timezone setting.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }