Add all the code for training service

This commit is contained in:
Urtzi Alfaro
2025-07-19 16:59:37 +02:00
parent 42097202d2
commit f3071c00bd
21 changed files with 7504 additions and 764 deletions

View File

@@ -8,10 +8,11 @@ from typing import List
import structlog
from app.core.database import get_db
from app.core.auth import verify_token
from app.core.auth import get_current_tenant_id
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
logger = structlog.get_logger()
router = APIRouter()
@@ -19,12 +20,12 @@ training_service = TrainingService()
@router.get("/", response_model=List[TrainedModelResponse])
async def get_trained_models(
user_data: dict = Depends(verify_token),
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get trained models"""
try:
return await training_service.get_trained_models(user_data, db)
return await training_service.get_trained_models(tenant_id, db)
except Exception as e:
logger.error(f"Get trained models error: {e}")
raise HTTPException(

View File

@@ -1,77 +1,299 @@
# services/training/app/api/training.py
"""
Training API endpoints
Training API endpoints for the training service
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
import structlog
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime
import uuid
from app.core.database import get_db
from app.core.auth import verify_token
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatusResponse,
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector
logger = structlog.get_logger()
logger = logging.getLogger(__name__)
router = APIRouter()
metrics = MetricsCollector("training-service")
# Initialize training service
training_service = TrainingService()
@router.post("/train", response_model=TrainingJobResponse)
async def start_training(
request: TrainingRequest,
user_data: dict = Depends(verify_token),
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Start training job"""
"""
Start a new training job for all products of a tenant.
Replaces the old Celery-based training system.
"""
try:
return await training_service.start_training(request, user_data, db)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
logger.info(f"Starting training job for tenant {tenant_id}")
metrics.increment_counter("training_jobs_started")
# Generate job ID
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
# Create training job record
training_job = await training_service.create_training_job(
db=db,
tenant_id=tenant_id,
job_id=job_id,
config=request.dict()
)
# Start training in background
background_tasks.add_task(
training_service.execute_training_job,
db,
job_id,
tenant_id,
request
)
# Publish training started event
await publish_job_started(job_id, tenant_id, request.dict())
return TrainingJobResponse(
job_id=job_id,
status="started",
message="Training job started successfully",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=request.estimated_duration or 15
)
except Exception as e:
logger.error(f"Training start error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start training"
)
logger.error(f"Failed to start training job: {str(e)}")
metrics.increment_counter("training_jobs_failed")
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
job_id: str,
user_data: dict = Depends(verify_token),
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get training job status"""
"""
Get the status of a training job.
Provides real-time progress updates.
"""
try:
return await training_service.get_training_status(job_id, user_data, db)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e)
# Get job status from database
job_status = await training_service.get_job_status(db, job_id, tenant_id)
if not job_status:
raise HTTPException(status_code=404, detail="Training job not found")
return TrainingStatusResponse(
job_id=job_id,
status=job_status.status,
progress=job_status.progress,
current_step=job_status.current_step,
started_at=job_status.start_time,
completed_at=job_status.end_time,
results=job_status.results,
error_message=job_status.error_message
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Get training status error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training status"
)
logger.error(f"Failed to get training status: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}")
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
limit: int = Query(10, ge=1, le=100),
offset: int = Query(0, ge=0),
user_data: dict = Depends(verify_token),
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
product_name: str,
request: SingleProductTrainingRequest,
background_tasks: BackgroundTasks,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""Get training jobs"""
"""
Train a model for a single product.
Useful for quick model updates or new products.
"""
try:
return await training_service.get_training_jobs(user_data, limit, offset, db)
logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
metrics.increment_counter("single_product_training_started")
# Generate job ID
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
# Create training job record
training_job = await training_service.create_single_product_job(
db=db,
tenant_id=tenant_id,
product_name=product_name,
job_id=job_id,
config=request.dict()
)
# Start training in background
background_tasks.add_task(
training_service.execute_single_product_training,
db,
job_id,
tenant_id,
product_name,
request
)
# Publish event
await publish_product_training_started(job_id, tenant_id, product_name)
return TrainingJobResponse(
job_id=job_id,
status="started",
message=f"Single product training started for {product_name}",
tenant_id=tenant_id,
created_at=training_job.start_time,
estimated_duration_minutes=5
)
except Exception as e:
logger.error(f"Get training jobs error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get training jobs"
)
logger.error(f"Failed to start single product training: {str(e)}")
metrics.increment_counter("single_product_training_failed")
raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}")
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
limit: int = 10,
status: Optional[str] = None,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
List training jobs for a tenant.
"""
try:
jobs = await training_service.list_training_jobs(
db=db,
tenant_id=tenant_id,
limit=limit,
status_filter=status
)
return [
TrainingStatusResponse(
job_id=job.job_id,
status=job.status,
progress=job.progress,
current_step=job.current_step,
started_at=job.start_time,
completed_at=job.end_time,
results=job.results,
error_message=job.error_message
)
for job in jobs
]
except Exception as e:
logger.error(f"Failed to list training jobs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}")
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
job_id: str,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Cancel a running training job.
"""
try:
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
# Update job status to cancelled
success = await training_service.cancel_training_job(db, job_id, tenant_id)
if not success:
raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")
# Publish cancellation event
await publish_job_cancelled(job_id, tenant_id)
return {"message": "Training job cancelled successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to cancel training job: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
job_id: str,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Get detailed logs for a training job.
"""
try:
logs = await training_service.get_training_logs(db, job_id, tenant_id)
if not logs:
raise HTTPException(status_code=404, detail="Training job not found")
return {"job_id": job_id, "logs": logs}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get training logs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}")
@router.post("/validate")
async def validate_training_data(
request: TrainingJobRequest,
tenant_id: str = Depends(get_current_tenant_id),
db: AsyncSession = Depends(get_db)
):
"""
Validate training data before starting a job.
Provides early feedback on data quality issues.
"""
try:
logger.info(f"Validating training data for tenant {tenant_id}")
# Perform data validation
validation_result = await training_service.validate_training_data(
db=db,
tenant_id=tenant_id,
config=request.dict()
)
return {
"is_valid": validation_result["is_valid"],
"issues": validation_result.get("issues", []),
"recommendations": validation_result.get("recommendations", []),
"estimated_training_time": validation_result.get("estimated_time_minutes", 15)
}
except Exception as e:
logger.error(f"Failed to validate training data: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}")
@router.get("/health")
async def health_check():
"""Health check for the training service"""
return {
"status": "healthy",
"service": "training-service",
"timestamp": datetime.now().isoformat()
}