Add all the code for training service
This commit is contained in:
@@ -1,77 +1,299 @@
|
||||
# services/training/app/api/training.py
|
||||
"""
|
||||
Training API endpoints
|
||||
Training API endpoints for the training service
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import verify_token
|
||||
from app.schemas.training import TrainingRequest, TrainingJobResponse, TrainedModelResponse
|
||||
from app.core.auth import get_current_tenant_id
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
TrainingJobResponse,
|
||||
TrainingStatusResponse,
|
||||
SingleProductTrainingRequest
|
||||
)
|
||||
from app.services.training_service import TrainingService
|
||||
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
logger = structlog.get_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
metrics = MetricsCollector("training-service")
|
||||
|
||||
# Initialize training service
|
||||
training_service = TrainingService()
|
||||
|
||||
@router.post("/train", response_model=TrainingJobResponse)
|
||||
async def start_training(
|
||||
request: TrainingRequest,
|
||||
user_data: dict = Depends(verify_token),
|
||||
@router.post("/jobs", response_model=TrainingJobResponse)
|
||||
async def start_training_job(
|
||||
request: TrainingJobRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Start training job"""
|
||||
"""
|
||||
Start a new training job for all products of a tenant.
|
||||
Replaces the old Celery-based training system.
|
||||
"""
|
||||
try:
|
||||
return await training_service.start_training(request, user_data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
logger.info(f"Starting training job for tenant {tenant_id}")
|
||||
metrics.increment_counter("training_jobs_started")
|
||||
|
||||
# Generate job ID
|
||||
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create training job record
|
||||
training_job = await training_service.create_training_job(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_training_job,
|
||||
db,
|
||||
job_id,
|
||||
tenant_id,
|
||||
request
|
||||
)
|
||||
|
||||
# Publish training started event
|
||||
await publish_job_started(job_id, tenant_id, request.dict())
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job_id,
|
||||
status="started",
|
||||
message="Training job started successfully",
|
||||
tenant_id=tenant_id,
|
||||
created_at=training_job.start_time,
|
||||
estimated_duration_minutes=request.estimated_duration or 15
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training start error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to start training"
|
||||
)
|
||||
logger.error(f"Failed to start training job: {str(e)}")
|
||||
metrics.increment_counter("training_jobs_failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
||||
|
||||
@router.get("/status/{job_id}", response_model=TrainingJobResponse)
|
||||
@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
|
||||
async def get_training_status(
|
||||
job_id: str,
|
||||
user_data: dict = Depends(verify_token),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get training job status"""
|
||||
"""
|
||||
Get the status of a training job.
|
||||
Provides real-time progress updates.
|
||||
"""
|
||||
try:
|
||||
return await training_service.get_training_status(job_id, user_data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
# Get job status from database
|
||||
job_status = await training_service.get_job_status(db, job_id, tenant_id)
|
||||
|
||||
if not job_status:
|
||||
raise HTTPException(status_code=404, detail="Training job not found")
|
||||
|
||||
return TrainingStatusResponse(
|
||||
job_id=job_id,
|
||||
status=job_status.status,
|
||||
progress=job_status.progress,
|
||||
current_step=job_status.current_step,
|
||||
started_at=job_status.start_time,
|
||||
completed_at=job_status.end_time,
|
||||
results=job_status.results,
|
||||
error_message=job_status.error_message
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Get training status error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training status"
|
||||
)
|
||||
logger.error(f"Failed to get training status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}")
|
||||
|
||||
@router.get("/jobs", response_model=List[TrainingJobResponse])
|
||||
async def get_training_jobs(
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
user_data: dict = Depends(verify_token),
|
||||
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
|
||||
async def train_single_product(
|
||||
product_name: str,
|
||||
request: SingleProductTrainingRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get training jobs"""
|
||||
"""
|
||||
Train a model for a single product.
|
||||
Useful for quick model updates or new products.
|
||||
"""
|
||||
try:
|
||||
return await training_service.get_training_jobs(user_data, limit, offset, db)
|
||||
logger.info(f"Starting single product training for {product_name}, tenant {tenant_id}")
|
||||
metrics.increment_counter("single_product_training_started")
|
||||
|
||||
# Generate job ID
|
||||
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create training job record
|
||||
training_job = await training_service.create_single_product_job(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
job_id=job_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
# Start training in background
|
||||
background_tasks.add_task(
|
||||
training_service.execute_single_product_training,
|
||||
db,
|
||||
job_id,
|
||||
tenant_id,
|
||||
product_name,
|
||||
request
|
||||
)
|
||||
|
||||
# Publish event
|
||||
await publish_product_training_started(job_id, tenant_id, product_name)
|
||||
|
||||
return TrainingJobResponse(
|
||||
job_id=job_id,
|
||||
status="started",
|
||||
message=f"Single product training started for {product_name}",
|
||||
tenant_id=tenant_id,
|
||||
created_at=training_job.start_time,
|
||||
estimated_duration_minutes=5
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get training jobs error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get training jobs"
|
||||
)
|
||||
logger.error(f"Failed to start single product training: {str(e)}")
|
||||
metrics.increment_counter("single_product_training_failed")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}")
|
||||
|
||||
@router.get("/jobs", response_model=List[TrainingStatusResponse])
|
||||
async def list_training_jobs(
|
||||
limit: int = 10,
|
||||
status: Optional[str] = None,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
List training jobs for a tenant.
|
||||
"""
|
||||
try:
|
||||
jobs = await training_service.list_training_jobs(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
status_filter=status
|
||||
)
|
||||
|
||||
return [
|
||||
TrainingStatusResponse(
|
||||
job_id=job.job_id,
|
||||
status=job.status,
|
||||
progress=job.progress,
|
||||
current_step=job.current_step,
|
||||
started_at=job.start_time,
|
||||
completed_at=job.end_time,
|
||||
results=job.results,
|
||||
error_message=job.error_message
|
||||
)
|
||||
for job in jobs
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list training jobs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}")
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel")
|
||||
async def cancel_training_job(
|
||||
job_id: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Cancel a running training job.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
# Update job status to cancelled
|
||||
success = await training_service.cancel_training_job(db, job_id, tenant_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")
|
||||
|
||||
# Publish cancellation event
|
||||
await publish_job_cancelled(job_id, tenant_id)
|
||||
|
||||
return {"message": "Training job cancelled successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel training job: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
|
||||
|
||||
@router.get("/jobs/{job_id}/logs")
|
||||
async def get_training_logs(
|
||||
job_id: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get detailed logs for a training job.
|
||||
"""
|
||||
try:
|
||||
logs = await training_service.get_training_logs(db, job_id, tenant_id)
|
||||
|
||||
if not logs:
|
||||
raise HTTPException(status_code=404, detail="Training job not found")
|
||||
|
||||
return {"job_id": job_id, "logs": logs}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get training logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}")
|
||||
|
||||
@router.post("/validate")
|
||||
async def validate_training_data(
|
||||
request: TrainingJobRequest,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate training data before starting a job.
|
||||
Provides early feedback on data quality issues.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Validating training data for tenant {tenant_id}")
|
||||
|
||||
# Perform data validation
|
||||
validation_result = await training_service.validate_training_data(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
config=request.dict()
|
||||
)
|
||||
|
||||
return {
|
||||
"is_valid": validation_result["is_valid"],
|
||||
"issues": validation_result.get("issues", []),
|
||||
"recommendations": validation_result.get("recommendations", []),
|
||||
"estimated_training_time": validation_result.get("estimated_time_minutes", 15)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate training data: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}")
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Health check for the training service"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "training-service",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
Reference in New Issue
Block a user