Files
bakery-ia/services/training/app/api/training.py

299 lines
9.8 KiB
Python
Raw Normal View History

2025-07-19 16:59:37 +02:00
# services/training/app/api/training.py
"""
2025-07-19 16:59:37 +02:00
Training API endpoints for the training service
"""
2025-07-19 16:59:37 +02:00
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
2025-07-19 16:59:37 +02:00
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime
import uuid
from app.core.database import get_db
2025-07-19 16:59:37 +02:00
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatusResponse,
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
2025-07-19 16:59:37 +02:00
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector
2025-07-19 16:59:37 +02:00
# Module-level objects shared by every endpoint defined in this router.
router = APIRouter()
logger = logging.getLogger(__name__)

# Metrics collector identified by the service name "training-service".
metrics = MetricsCollector("training-service")

# One TrainingService instance, used by all endpoints below.
training_service = TrainingService()
2025-07-19 16:59:37 +02:00
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all products of a tenant.
    Replaces the old Celery-based training system.

    Creates the job record, queues ``training_service.execute_training_job``
    on ``background_tasks`` (FastAPI runs it after the response is sent),
    publishes a job-started event and returns immediately.

    Args:
        request: Training configuration. Its ``dict()`` is stored as the job
            config and sent with the started event.
        background_tasks: FastAPI background task queue.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        TrainingJobResponse with ``status="started"``. The estimated duration
        falls back to 15 minutes when ``request.estimated_duration`` is falsy.

    Raises:
        HTTPException: propagated unchanged if raised inside the handler.
            Any other exception is logged with its traceback and converted
            into a 500.
    """
    try:
        logger.info("Starting training job for tenant %s", tenant_id)
        metrics.increment_counter("training_jobs_started")

        # A random 8-hex-char suffix makes ID collisions between jobs of the
        # same tenant unlikely.
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        training_job = await training_service.create_training_job(
            db=db,
            tenant_id=tenant_id,
            job_id=job_id,
            config=request.dict()
        )

        # NOTE(review): `db` is the request-scoped session from `get_db`.
        # Depending on the FastAPI version and how `get_db` is written, that
        # session may already be closed when this background task runs.
        # Confirm this, or let the task open its own session.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job_id,
            tenant_id,
            request
        )

        # NOTE(review): if publishing fails, the job record above already
        # exists but the client receives a 500. The queued background task
        # then most likely never runs. Consider making this step best-effort.
        await publish_job_started(job_id, tenant_id, request.dict())

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message="Training job started successfully",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=request.estimated_duration or 15
        )

    except HTTPException:
        # Keep intentional HTTP errors (status + detail) instead of masking
        # them as 500s, but still count the failure.
        metrics.increment_counter("training_jobs_failed")
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Failed to start training job: %s", e)
        metrics.increment_counter("training_jobs_failed")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training job: {str(e)}"
        ) from e
2025-07-19 16:59:37 +02:00
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the status of a training job.

    Fetches the job via ``training_service.get_job_status`` (passing the
    tenant ID) and maps the record onto ``TrainingStatusResponse``.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        TrainingStatusResponse built from the job record's status, progress,
        current step, start/end times, results and error message.

    Raises:
        HTTPException: 404 if the service returns no job. 500 (logged with
            traceback) for any unexpected error.
    """
    try:
        job_status = await training_service.get_job_status(db, job_id, tenant_id)
        if not job_status:
            raise HTTPException(status_code=404, detail="Training job not found")

        return TrainingStatusResponse(
            job_id=job_id,
            status=job_status.status,
            progress=job_status.progress,
            current_step=job_status.current_step,
            started_at=job_status.start_time,
            completed_at=job_status.end_time,
            results=job_status.results,
            error_message=job_status.error_message
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get training status: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training status: {str(e)}"
        ) from e
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Train a model for a single product.
    Useful for quick model updates or new products.

    Creates the job record, queues
    ``training_service.execute_single_product_training`` on
    ``background_tasks`` (FastAPI runs it after the response is sent),
    publishes a product-training-started event and returns immediately.

    Args:
        product_name: Product to train, taken from the URL path.
        request: Training configuration. Its ``dict()`` is stored as the job config.
        background_tasks: FastAPI background task queue.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        TrainingJobResponse with ``status="started"`` and a fixed 5-minute
        duration estimate.

    Raises:
        HTTPException: propagated unchanged if raised inside the handler.
            Any other exception is logged with its traceback and converted
            into a 500.
    """
    try:
        logger.info(
            "Starting single product training for %s, tenant %s",
            product_name, tenant_id
        )
        metrics.increment_counter("single_product_training_started")

        # A random 8-hex-char suffix makes ID collisions between runs for the
        # same product unlikely.
        job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        training_job = await training_service.create_single_product_job(
            db=db,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=request.dict()
        )

        # NOTE(review): `db` is the request-scoped session from `get_db`.
        # It may already be closed when this background task runs (depends
        # on the FastAPI version and `get_db`). Confirm this, or let the task
        # open its own session.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            db,
            job_id,
            tenant_id,
            product_name,
            request
        )

        # NOTE(review): a publish failure here returns a 500 although the job
        # record was already created. Consider making this step best-effort.
        await publish_product_training_started(job_id, tenant_id, product_name)

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message=f"Single product training started for {product_name}",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=5
        )

    except HTTPException:
        # Preserve intentional HTTP errors, but still count the failure.
        metrics.increment_counter("single_product_training_failed")
        raise
    except Exception as e:
        logger.exception("Failed to start single product training: %s", e)
        metrics.increment_counter("single_product_training_failed")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training: {str(e)}"
        ) from e
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
    limit: int = 10,
    status: Optional[str] = None,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    List training jobs for a tenant.

    Args:
        limit: Forwarded to ``training_service.list_training_jobs`` as ``limit``
            (default 10).
        status: Optional filter, forwarded as ``status_filter``.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        One TrainingStatusResponse per job returned by the service.

    Raises:
        HTTPException: propagated unchanged if raised inside the handler.
            Any other exception is logged with its traceback and converted
            into a 500.
    """
    # NOTE(review): `limit` is passed through unvalidated (negative or very
    # large values reach the service). Consider bounding it, e.g. with
    # `Query(10, ge=1, le=...)`.
    try:
        jobs = await training_service.list_training_jobs(
            db=db,
            tenant_id=tenant_id,
            limit=limit,
            status_filter=status
        )

        return [
            TrainingStatusResponse(
                job_id=job.job_id,
                status=job.status,
                progress=job.progress,
                current_step=job.current_step,
                started_at=job.start_time,
                completed_at=job.end_time,
                results=job.results,
                error_message=job.error_message
            )
            for job in jobs
        ]

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list training jobs: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list training jobs: {str(e)}"
        ) from e
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Asks ``training_service.cancel_training_job`` to cancel the job, then
    publishes a cancellation event.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"message": "Training job cancelled successfully"}``.

    Raises:
        HTTPException: 404 if the service reports the job could not be
            cancelled. 500 (logged with traceback) for any unexpected error.
    """
    try:
        logger.info("Cancelling training job %s for tenant %s", job_id, tenant_id)

        success = await training_service.cancel_training_job(db, job_id, tenant_id)
        if not success:
            raise HTTPException(
                status_code=404,
                detail="Training job not found or cannot be cancelled"
            )

        # NOTE(review): if publishing fails, the cancellation has already been
        # applied by the service, but the client receives a 500.
        await publish_job_cancelled(job_id, tenant_id)

        return {"message": "Training job cancelled successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to cancel training job: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(e)}"
        ) from e
2025-07-19 16:59:37 +02:00
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get detailed logs for a training job.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"job_id": job_id, "logs": logs}`` with the service's log payload.

    Raises:
        HTTPException: 404 when the service returns a falsy result. 500
            (logged with traceback) for any unexpected error.
    """
    try:
        logs = await training_service.get_training_logs(db, job_id, tenant_id)
        # NOTE(review): `not logs` also reports an existing job whose log list
        # is empty as "not found". If the service returns None for unknown
        # jobs, `logs is None` would be the precise check. Confirm the contract.
        if not logs:
            raise HTTPException(status_code=404, detail="Training job not found")

        return {"job_id": job_id, "logs": logs}

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get training logs: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training logs: {str(e)}"
        ) from e
@router.post("/validate")
async def validate_training_data(
    request: TrainingJobRequest,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate training data before starting a job.
    Provides early feedback on data quality issues.

    Args:
        request: Training configuration. Its ``dict()`` is passed to the
            validator as ``config``.
        tenant_id: Tenant ID supplied by the ``get_current_tenant_id`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        Dict with ``is_valid`` (required in the service result), ``issues``
        and ``recommendations`` (default ``[]``), and
        ``estimated_training_time`` (from ``estimated_time_minutes``,
        default 15).

    Raises:
        HTTPException: propagated unchanged if raised inside the handler.
            Any other exception (including a missing ``is_valid`` key) is
            logged with its traceback and converted into a 500.
    """
    try:
        logger.info("Validating training data for tenant %s", tenant_id)

        validation_result = await training_service.validate_training_data(
            db=db,
            tenant_id=tenant_id,
            config=request.dict()
        )

        return {
            "is_valid": validation_result["is_valid"],
            "issues": validation_result.get("issues", []),
            "recommendations": validation_result.get("recommendations", []),
            "estimated_training_time": validation_result.get("estimated_time_minutes", 15)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to validate training data: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to validate training data: {str(e)}"
        ) from e
@router.get("/health")
async def health_check():
    """Report that the training service is up.

    Returns a fixed status and service name plus the current local time
    (naive ``datetime.now()``) as an ISO-8601 string.
    """
    now_iso = datetime.now().isoformat()
    payload = {"status": "healthy", "service": "training-service"}
    payload["timestamp"] = now_iso
    return payload