# services/training/app/api/training.py
"""
Training API endpoints for the training service
"""

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime
import uuid

from app.core.database import get_db
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
    TrainingJobRequest,
    TrainingJobResponse,
    TrainingStatusResponse,
    SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector

logger = logging.getLogger(__name__)
router = APIRouter()
metrics = MetricsCollector("training-service")

# Initialize training service
training_service = TrainingService()


def _to_status_response(job: Any, job_id: Optional[str] = None) -> TrainingStatusResponse:
    """Build a TrainingStatusResponse from a job record returned by the training service.

    Args:
        job: Job record. Reads ``status``, ``progress``, ``current_step``,
            ``start_time``, ``end_time``, ``results`` and ``error_message``,
            plus ``job_id`` only when *job_id* is not given.
        job_id: Job ID to report; falls back to ``job.job_id`` when None.

    Returns:
        The populated TrainingStatusResponse.
    """
    return TrainingStatusResponse(
        job_id=job_id if job_id is not None else job.job_id,
        status=job.status,
        progress=job.progress,
        current_step=job.current_step,
        started_at=job.start_time,
        completed_at=job.end_time,
        results=job.results,
        error_message=job.error_message
    )


@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Start a new training job for all products of a tenant.
    Replaces the old Celery-based training system.

    Creates the job record, schedules ``training_service.execute_training_job``
    as a background task, publishes a job-started event and returns at once.

    Raises:
        HTTPException: re-raised unchanged when raised by downstream code;
            otherwise a 500 wrapping any unexpected error.
    """
    try:
        logger.info("Starting training job for tenant %s", tenant_id)
        metrics.increment_counter("training_jobs_started")

        # Generate job ID (short random suffix keeps IDs unique per tenant)
        job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_training_job(
            db=db,
            tenant_id=tenant_id,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): `db` is the request-scoped session from get_db. Since
        # FastAPI 0.106, dependencies using `yield` are torn down before
        # background tasks run, so this session may already be closed when the
        # task executes — confirm against get_db, or have the task open its own.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job_id,
            tenant_id,
            request
        )

        # Publish training started event
        await publish_job_started(job_id, tenant_id, request.dict())

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message="Training job started successfully",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=request.estimated_duration or 15
        )

    except HTTPException:
        # Keep the original status code (e.g. 4xx from the service) instead of masking it as 500
        metrics.increment_counter("training_jobs_failed")
        raise
    except Exception as e:
        logger.exception("Failed to start training job: %s", e)
        metrics.increment_counter("training_jobs_failed")
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") from e


@router.get("/jobs/{job_id}/status", response_model=TrainingStatusResponse)
async def get_training_status(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the status of a training job.
    Provides real-time progress updates.

    Raises:
        HTTPException: 404 if the service returns no job for this tenant;
            500 on any unexpected error.
    """
    try:
        # Get job status from database
        job_status = await training_service.get_job_status(db, job_id, tenant_id)

        if not job_status:
            raise HTTPException(status_code=404, detail="Training job not found")

        # Report the requested job_id (not the record's) as the original endpoint did
        return _to_status_response(job_status, job_id=job_id)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get training status: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to get training status: {str(e)}") from e


@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Train a model for a single product.
    Useful for quick model updates or new products.

    Creates the job record, schedules
    ``training_service.execute_single_product_training`` as a background task,
    publishes a product-training-started event and returns at once.

    Raises:
        HTTPException: re-raised unchanged when raised by downstream code;
            otherwise a 500 wrapping any unexpected error.
    """
    try:
        logger.info("Starting single product training for %s, tenant %s", product_name, tenant_id)
        metrics.increment_counter("single_product_training_started")

        # Generate job ID
        job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        # Create training job record
        training_job = await training_service.create_single_product_job(
            db=db,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=request.dict()
        )

        # Start training in background.
        # NOTE(review): same request-scoped `db` session caveat as in start_training_job.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            db,
            job_id,
            tenant_id,
            product_name,
            request
        )

        # Publish event
        await publish_product_training_started(job_id, tenant_id, product_name)

        return TrainingJobResponse(
            job_id=job_id,
            status="started",
            message=f"Single product training started for {product_name}",
            tenant_id=tenant_id,
            created_at=training_job.start_time,
            estimated_duration_minutes=5
        )

    except HTTPException:
        # Keep the original status code instead of masking it as 500
        metrics.increment_counter("single_product_training_failed")
        raise
    except Exception as e:
        logger.exception("Failed to start single product training: %s", e)
        metrics.increment_counter("single_product_training_failed")
        raise HTTPException(status_code=500, detail=f"Failed to start training: {str(e)}") from e


@router.get("/jobs", response_model=List[TrainingStatusResponse])
async def list_training_jobs(
    limit: int = 10,
    status: Optional[str] = None,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    List training jobs for a tenant.

    Args:
        limit: Maximum number of jobs, passed through to the service (default 10).
        status: Optional status filter, passed through as ``status_filter``.

    Raises:
        HTTPException: re-raised unchanged when raised by downstream code;
            otherwise a 500 wrapping any unexpected error.
    """
    try:
        jobs = await training_service.list_training_jobs(
            db=db,
            tenant_id=tenant_id,
            limit=limit,
            status_filter=status
        )

        return [_to_status_response(job) for job in jobs]

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list training jobs: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to list training jobs: {str(e)}") from e


@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Publishes a job-cancelled event only after the service reports success.

    Raises:
        HTTPException: 404 if the service reports the job was not found or
            could not be cancelled; 500 on any unexpected error.
    """
    try:
        logger.info("Cancelling training job %s for tenant %s", job_id, tenant_id)

        # Update job status to cancelled
        success = await training_service.cancel_training_job(db, job_id, tenant_id)

        if not success:
            raise HTTPException(status_code=404, detail="Training job not found or cannot be cancelled")

        # Publish cancellation event
        await publish_job_cancelled(job_id, tenant_id)

        return {"message": "Training job cancelled successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to cancel training job: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") from e


@router.get("/jobs/{job_id}/logs")
async def get_training_logs(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Get detailed logs for a training job.

    Returns:
        ``{"job_id": ..., "logs": ...}`` with whatever the service returned.

    Raises:
        HTTPException: 404 if the service returns no (or empty) logs;
            500 on any unexpected error.
    """
    try:
        logs = await training_service.get_training_logs(db, job_id, tenant_id)

        # NOTE(review): an empty log collection is also reported as "not found"
        if not logs:
            raise HTTPException(status_code=404, detail="Training job not found")

        return {"job_id": job_id, "logs": logs}

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get training logs: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to get training logs: {str(e)}") from e


@router.post("/validate")
async def validate_training_data(
    request: TrainingJobRequest,
    tenant_id: str = Depends(get_current_tenant_id),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate training data before starting a job.
    Provides early feedback on data quality issues.

    Returns:
        Dict with ``is_valid`` (required key of the service result), ``issues``
        and ``recommendations`` (default ``[]``), and ``estimated_training_time``
        (taken from ``estimated_time_minutes``, default 15).

    Raises:
        HTTPException: re-raised unchanged when raised by downstream code;
            otherwise a 500 wrapping any unexpected error.
    """
    try:
        logger.info("Validating training data for tenant %s", tenant_id)

        # Perform data validation
        validation_result = await training_service.validate_training_data(
            db=db,
            tenant_id=tenant_id,
            config=request.dict()
        )

        return {
            "is_valid": validation_result["is_valid"],
            "issues": validation_result.get("issues", []),
            "recommendations": validation_result.get("recommendations", []),
            "estimated_training_time": validation_result.get("estimated_time_minutes", 15)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to validate training data: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to validate training data: {str(e)}") from e


@router.get("/health")
async def health_check():
    """Health check for the training service.

    Returns a static "healthy" payload with the server's local (naive) timestamp.
    """
    return {
        "status": "healthy",
        "service": "training-service",
        "timestamp": datetime.now().isoformat()
    }