Improve gateway service

This commit is contained in:
Urtzi Alfaro
2025-07-20 07:24:04 +02:00
parent 1c730c3c81
commit 8cd433c0cd
4 changed files with 816 additions and 373 deletions

View File

@@ -1,299 +1,472 @@
# services/training/app/api/training.py
"""
Training API endpoints for the training service
"""
# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, List, Any, Optional
import logging
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any
from datetime import datetime
import uuid
import structlog
from app.core.database import get_db
from app.core.auth import get_current_tenant_id
from app.schemas.training import (
TrainingJobRequest,
TrainingJobRequest,
TrainingJobResponse,
TrainingStatusResponse,
SingleProductTrainingRequest
TrainingJobStatus,
SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
)
from app.services.training_service import TrainingService
from app.services.messaging import publish_job_started, publish_job_cancelled, publish_product_training_started
from shared.monitoring.metrics import MetricsCollector
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
)
logger = logging.getLogger(__name__)
router = APIRouter()
metrics = MetricsCollector("training-service")
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
# Initialize training service
training_service = TrainingService()
logger = structlog.get_logger()
router = APIRouter(prefix="/training", tags=["training"])
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Start a new training job for all products.

    Creates the job through ``training_service.create_training_job``,
    publishes a job-started event, and schedules
    ``training_service.execute_training_job`` as a background task.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 500 when job creation or scheduling fails.
    """
    try:
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"],
                    config=request.dict())

        # Create training job. The job id is read back from the returned
        # object (job.job_id) rather than generated in this handler.
        job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config=request.dict()
        )

        # Publish job started event. Best effort: a messaging failure is
        # logged and must not abort a job that was already created.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config=request.dict()
            )
        except Exception as publish_error:
            logger.warning("Failed to publish job started event", error=str(publish_error))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_training_job,
            job.job_id
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)
        return job
    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}") from e
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingJobStatus] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training jobs for the authenticated tenant.

    Args:
        status: Optional status filter forwarded to the service.
        limit: Page size, constrained to 1..1000 (default 100).
        offset: Number of jobs to skip, must be >= 0 (default 0).

    Returns:
        Whatever ``training_service.get_training_jobs`` returns for the
        tenant with the given filter and paging values.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    try:
        logger.debug("Getting training jobs",
                     tenant_id=tenant_id,
                     status=status,
                     limit=limit,
                     offset=offset)

        jobs = await training_service.get_training_jobs(
            tenant_id=tenant_id,
            status=status,
            limit=limit,
            offset=offset
        )

        logger.debug("Retrieved training jobs",
                     count=len(jobs),
                     tenant_id=tenant_id)
        return jobs
    except Exception as e:
        logger.error("Failed to get training jobs",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}") from e
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get the details of one training job.

    Returns:
        The job object returned by ``training_service.get_training_job``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)

        job = await training_service.get_training_job(job_id)

        # NOTE(review): whether the service returns None or raises for an
        # unknown id is not visible here; this guard covers the None case so
        # it surfaces as 404 instead of an AttributeError turned into 500.
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")

        # Verify tenant access. A foreign job is reported as 404, the same
        # response as a missing job.
        if job.tenant_id != tenant_id:
            logger.warning("Unauthorized job access attempt",
                           job_id=job_id,
                           tenant_id=tenant_id,
                           job_tenant_id=job.tenant_id)
            raise HTTPException(status_code=404, detail="Job not found")

        return job
    except HTTPException:
        # Let deliberate 404s through untouched.
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}") from e
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get the progress of a training job owned by the tenant.

    Returns:
        The value returned by ``training_service.get_job_progress``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job belongs to tenant before exposing its progress.
        # NOTE(review): the None check assumes the service may return None
        # for an unknown id; confirm against TrainingService.
        job = await training_service.get_training_job(job_id)
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        progress = await training_service.get_job_progress(job_id)
        return progress
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}") from e
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Cancel a training job owned by the tenant.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify tenant access.
        # NOTE(review): the None check assumes the service may return None
        # for an unknown id; confirm against TrainingService.
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event. It is sent through the job-failed
        # publisher and is best effort: a failure here is only logged.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as publish_error:
            logger.warning("Failed to publish cancellation event", error=str(publish_error))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}") from e
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Train a model for a single product.

    Creates the job through ``training_service.create_single_product_job``,
    publishes a product-training-started event, and schedules
    ``training_service.execute_single_product_training`` as a background
    task.

    Returns:
        The job object returned by ``create_single_product_job``.

    Raises:
        HTTPException: 500 when job creation or scheduling fails.
    """
    try:
        logger.info("Training single product",
                    product_name=product_name,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Create training job for single product.
        # NOTE(review): no job_id is passed; the id is read back from the
        # returned job, mirroring how create_training_job is called
        # elsewhere in this module. Confirm against the service signature.
        job = await training_service.create_single_product_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            product_name=product_name,
            config=request.dict()
        )

        # Publish event. Best effort: a messaging failure is only logged.
        try:
            await publish_product_training_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                product_name=product_name
            )
        except Exception as publish_error:
            logger.warning("Failed to publish product training event", error=str(publish_error))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_single_product_training,
            job.job_id,
            product_name
        )

        logger.info("Single product training started",
                    job_id=job.job_id,
                    product_name=product_name)
        return job
    except Exception as e:
        logger.error("Failed to train single product",
                     error=str(e),
                     product_name=product_name,
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}") from e
@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Check the tenant's data before a training run.

    Forwards ``request.products`` and ``request.min_data_points`` to
    ``training_service.validate_training_data`` and returns its result
    unchanged. Any failure is reported as HTTP 500.
    """
    try:
        logger.debug(
            "Validating training data",
            tenant_id=tenant_id,
            products=request.products,
        )

        outcome = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points,
        )

        logger.debug(
            "Data validation completed",
            is_valid=outcome.is_valid,
            tenant_id=tenant_id,
        )
        return outcome
    except Exception as exc:
        logger.error(
            "Failed to validate training data",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to validate data: {str(exc)}",
        )
@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """List the tenant's trained models.

    ``product_name`` is an optional query parameter passed through to
    ``training_service.get_trained_models``. The service result is returned
    unchanged; any failure is reported as HTTP 500.
    """
    try:
        logger.debug(
            "Getting trained models",
            tenant_id=tenant_id,
            product_name=product_name,
        )

        trained_models = await training_service.get_trained_models(
            tenant_id=tenant_id,
            product_name=product_name,
        )

        logger.debug(
            "Retrieved trained models",
            count=len(trained_models),
            tenant_id=tenant_id,
        )
        return trained_models
    except Exception as exc:
        logger.error(
            "Failed to get trained models",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get models: {str(exc)}",
        )
@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Delete a trained model owned by the tenant (admin only).

    The role restriction comes from the ``require_role("admin")`` decorator.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model is missing, belongs to another
            tenant, or the service reports that nothing was deleted; 500 on
            any other failure.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model belongs to tenant.
        # NOTE(review): the None check assumes get_model may return None for
        # an unknown id; without it that case would raise AttributeError and
        # surface as 500. Confirm against TrainingService.
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)
        return {"message": "Model deleted successfully", "model_id": model_id}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") from e
@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Return training statistics for the tenant.

    The optional ``start_date`` / ``end_date`` query parameters are passed
    straight to ``training_service.get_training_stats``. The service result
    is returned unchanged; any failure is reported as HTTP 500.
    """
    try:
        logger.debug(
            "Getting training stats",
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
        )

        tenant_stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)
        return tenant_stats
    except Exception as exc:
        logger.error(
            "Failed to get training stats",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get stats: {str(exc)}",
        )
@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Retrain all products for a tenant that already has trained models.

    Creates a training job whose config is the request plus
    ``"is_retrain": True``, publishes a job-started event, and schedules
    ``training_service.execute_training_job`` as a background task.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 400 when the tenant has no existing models; 500 on
            any other failure.
    """
    try:
        logger.info("Retraining all products",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Check if models exist; retraining needs an earlier training run.
        existing_models = await training_service.get_trained_models(tenant_id)
        if not existing_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first."
            )

        # Create retraining job
        job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config={**request.dict(), "is_retrain": True}
        )

        # Publish event. Best effort: a messaging failure is only logged.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config={**request.dict(), "is_retrain": True}
            )
        except Exception as publish_error:
            logger.warning("Failed to publish retrain event", error=str(publish_error))

        # Start retraining in background
        background_tasks.add_task(
            training_service.execute_training_job,
            job.job_id
        )

        logger.info("Retraining job created", job_id=job.job_id)
        return job
    except HTTPException:
        # Keep the deliberate 400 from being rewrapped as a 500.
        raise
    except Exception as e:
        logger.error("Failed to start retraining",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}") from e