2025-07-19 16:59:37 +02:00
|
|
|
# services/training/app/api/training.py
|
2025-07-17 13:09:24 +02:00
|
|
|
"""
|
2025-07-19 16:59:37 +02:00
|
|
|
Training API endpoints for the training service
|
2025-07-17 13:09:24 +02:00
|
|
|
"""
|
|
|
|
|
|
2025-07-19 16:59:37 +02:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
2025-07-17 13:09:24 +02:00
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2025-07-19 16:59:37 +02:00
|
|
|
from typing import Dict, List, Any, Optional
|
|
|
|
|
import logging
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import uuid
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
from app.core.database import get_db
|
2025-07-19 16:59:37 +02:00
|
|
|
from app.core.auth import get_current_tenant_id
|
|
|
|
|
from app.schemas.training import (
|
|
|
|
|
TrainingJobRequest,
|
|
|
|
|
TrainingJobResponse,
|
|
|
|
|
TrainingStatusResponse,
|
|
|
|
|
SingleProductTrainingRequest
|
|
|
|
|
)
|
2025-07-17 13:09:24 +02:00
|
|
|
from app.services.training_service import TrainingService
|
2025-07-19 16:59:37 +02:00
|
|
|
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
|
|
|
|
|
from shared.monitoring.metrics import MetricsCollector
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 16:59:37 +02:00
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router on which the endpoint decorators in this module register routes.
router = APIRouter()

# Metrics collector constructed with the name "training-service";
# the handlers in this module increment counters on it.
metrics = MetricsCollector("training-service")

# Single TrainingService instance shared by all handlers in this module.
training_service = TrainingService()
|
|
|
|
|
|
2025-07-19 16:59:37 +02:00
|
|
|
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all products of a tenant.

    Replaces the old Celery-based training system.

    The handler creates a job record through the training service, queues
    the training run as a FastAPI background task, publishes a
    "job started" event and then returns the job descriptor.

    Args:
        request: Training configuration. Passed to the service and to the
            event publisher as ``request.dict()`` and to the background
            task as the request object itself.
        background_tasks: FastAPI task queue that will run the training.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        TrainingJobResponse with status "started" and the generated job id.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        logger.info(f"Starting training job for tenant {tenant_id}")
        metrics.increment_counter("training_jobs_started")

        # Job id: fixed prefix + tenant + first 8 hex chars of a random UUID.
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_training_job(
            db=db,
            tenant_id=tenant_id,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): the request-scoped session from get_db is handed to
        # the background task. Depending on how get_db manages the session
        # lifetime (and on the FastAPI version), it may already be closed
        # when the task runs - confirm, or open a session inside the task.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job_id,
            tenant_id,
            request
        )

        # Publish training started event
        await publish_job_started(job_id, tenant_id, request.dict())

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message="Training job started successfully",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            # Fall back to 15 when no (or a falsy) estimate is supplied.
            estimated_duration_minutes=request.estimated_duration or 15
        )

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to start training job: {str(e)}")
        metrics.increment_counter("training_jobs_failed")
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") from e
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 16:59:37 +02:00
|
|
|
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the status of a training job.

    Looks the job up through the training service (scoped by tenant) and
    maps its fields onto a TrainingStatusResponse.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        TrainingStatusResponse built from the stored job status.

    Raises:
        HTTPException: 404 when the service returns no (falsy) status,
            500 for any other failure.
    """
    try:
        # Get job status from database
        job_status = await training_service.get_job_status(db, job_id, tenant_id)

        if not job_status:
            raise HTTPException(status_code=404, detail="Training job not found")

        return TrainingStatusResponse(
            job_id=job_id,
            status=job_status.status,
            progress=job_status.progress,
            current_step=job_status.current_step,
            started_at=job_status.start_time,
            completed_at=job_status.end_time,
            results=job_status.results,
            error_message=job_status.error_message
        )

    except HTTPException:
        # Let the 404 above (and any other HTTP error) pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to get training status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Train a model for a single product.

    Useful for quick model updates or new products.

    The handler creates a single-product job record through the training
    service, queues the training run as a FastAPI background task,
    publishes a "product training started" event and returns the job
    descriptor.

    Args:
        product_name: Product to train, taken from the URL path.
        request: Training configuration. Passed to the service as
            ``request.dict()`` and to the background task as the request
            object itself.
        background_tasks: FastAPI task queue that will run the training.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        TrainingJobResponse with status "started" and a fixed estimate of
        5 minutes.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
        metrics.increment_counter("single_product_training_started")

        # Job id: prefix + tenant + product + first 8 hex chars of a random UUID.
        job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_single_product_job(
            db=db,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): the request-scoped session from get_db is handed to
        # the background task. Depending on how get_db manages the session
        # lifetime (and on the FastAPI version), it may already be closed
        # when the task runs - confirm, or open a session inside the task.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            db,
            job_id,
            tenant_id,
            product_name,
            request
        )

        # Publish event
        await publish_product_training_started(job_id, tenant_id, product_name)

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message=f"Single product training started for {product_name}",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=5
        )

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to start single product training: {str(e)}")
        metrics.increment_counter("single_product_training_failed")
        raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
    limit: int = 10,
    status: Optional[str] = None,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    List training jobs for a tenant.

    Args:
        limit: Maximum number of jobs requested from the service
            (defaults to 10).
        status: Optional status value forwarded to the service as
            ``status_filter``.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A list of TrainingStatusResponse objects, one per job returned by
        the service.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler;
            any other failure is reported as a 500.
    """
    try:
        # NOTE(review): `limit` is forwarded as received; no lower or upper
        # bound is enforced here - confirm the service validates it.
        jobs = await training_service.list_training_jobs(
            db=db,
            tenant_id=tenant_id,
            limit=limit,
            status_filter=status
        )

        return [
            TrainingStatusResponse(
                job_id=job.job_id,
                status=job.status,
                progress=job.progress,
                current_step=job.current_step,
                started_at=job.start_time,
                completed_at=job.end_time,
                results=job.results,
                error_message=job.error_message
            )
            for job in jobs
        ]

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to list training jobs: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Asks the training service to cancel the job and, when that succeeds,
    publishes a cancellation event.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A dict with a single "message" key confirming the cancellation.

    Raises:
        HTTPException: 404 when the service reports a falsy result,
            500 for any other failure.
    """
    try:
        logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")

        # Update job status to cancelled
        success = await training_service.cancel_training_job(db, job_id, tenant_id)

        if not success:
            raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")

        # Publish cancellation event (only reached after a successful cancel).
        await publish_job_cancelled(job_id, tenant_id)

        return {"message": "Training job cancelled successfully"}

    except HTTPException:
        # Let the 404 above (and any other HTTP error) pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to cancel training job: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") from e
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 16:59:37 +02:00
|
|
|
@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get detailed logs for a training job.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A dict with the keys "job_id" and "logs".

    Raises:
        HTTPException: 404 when the service returns a falsy value,
            500 for any other failure.
    """
    try:
        logs = await training_service.get_training_logs(db, job_id, tenant_id)

        # NOTE(review): this truthiness check also answers 404 for a job
        # whose log collection exists but is empty. Confirm what the service
        # returns for unknown jobs before tightening this to `is None`.
        if not logs:
            raise HTTPException(status_code=404, detail="Training job not found")

        return {"job_id": job_id, "logs": logs}

    except HTTPException:
        # Let the 404 above (and any other HTTP error) pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to get training logs: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@router.post("/validate")
async def validate_training_data(
    request: TrainingJobRequest,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate training data before starting a job.

    Provides early feedback on data quality issues.

    Args:
        request: Training configuration, passed to the service as
            ``request.dict()``.
        tenant_id: Tenant resolved by the ``get_current_tenant_id`` dependency.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A dict with the keys "is_valid", "issues", "recommendations" and
        "estimated_training_time", filled from the service's result.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler;
            any other failure (including a result without "is_valid") is
            reported as a 500.
    """
    try:
        logger.info(f"Validating training data for tenant {tenant_id}")

        # Perform data validation
        validation_result = await training_service.validate_training_data(
            db=db,
            tenant_id=tenant_id,
            config=request.dict()
        )

        return {
            # "is_valid" is mandatory in the result; the rest have defaults.
            "is_valid": validation_result["is_valid"],
            "issues": validation_result.get("issues", []),
            "recommendations": validation_result.get("recommendations", []),
            "estimated_training_time": validation_result.get("estimated_time_minutes", 15)
        }

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to validate training data: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """
    Health check for the training service.

    Returns:
        A dict with a fixed "status" and "service" value and a "timestamp"
        holding the current time as a timezone-aware ISO-8601 string in UTC.
    """
    # Local import: only `datetime` is imported at module level.
    from datetime import timezone

    return {
        "status": "healthy",
        "service": "training-service",
        # Timezone-aware UTC so the value is unambiguous across hosts.
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
|