Files
bakery-ia/services/training/app/api/training.py

472 lines
16 KiB
Python
Raw Normal View History

2025-07-20 07:24:04 +02:00
# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""
2025-07-20 07:24:04 +02:00
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any
2025-07-19 16:59:37 +02:00
from datetime import datetime
2025-07-20 07:24:04 +02:00
import structlog
2025-07-19 16:59:37 +02:00
from app.schemas.training import (
2025-07-20 07:24:04 +02:00
TrainingJobRequest,
2025-07-19 16:59:37 +02:00
TrainingJobResponse,
2025-07-20 07:43:45 +02:00
TrainingStatus,
2025-07-20 07:24:04 +02:00
SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
2025-07-19 16:59:37 +02:00
)
from app.services.training_service import TrainingService
2025-07-20 07:24:04 +02:00
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
)
2025-07-20 07:24:04 +02:00
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
2025-07-20 07:24:04 +02:00
# Module-level structlog logger used by every endpoint in this file.
logger = structlog.get_logger()
# Router for the endpoints below: paths are mounted under "/training" and
# tagged "training".
router = APIRouter(prefix="/training", tags=["training"])
2025-07-19 16:59:37 +02:00
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Start a new training job for all products.

    Creates the job through the training service, publishes a "job started"
    event on a best-effort basis, and schedules the job's execution on the
    request's background task queue.

    Args:
        request: Training configuration; its ``dict()`` form is used as the
            job config.
        background_tasks: Queue on which ``execute_training_job`` is scheduled.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; ``user_id`` is read from it.
        training_service: Service that creates and executes the job.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: Re-raised unchanged when raised inside the handler
            (e.g. by the service layer); status 500 for any other failure.
    """
    try:
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"],
                    config=request.dict())

        # Create training job
        job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config=request.dict()
        )

        # Publish job started event. A messaging failure must not abort the
        # request, so it is only logged.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config=request.dict()
            )
        except Exception as e:
            logger.warning("Failed to publish job started event", error=str(e))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_training_job,
            job.job_id
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)
        return job
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        # NOTE(review): the raw exception text is returned to the client here;
        # confirm that exposing internal error details is acceptable.
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training jobs for tenant.

    Args:
        status: Optional status filter forwarded to the service.
        limit: Page size, validated to the range 1..1000 (default 100).
        offset: Number of jobs to skip, validated to be >= 0 (default 0).
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service queried for the jobs.

    Returns:
        The collection returned by ``training_service.get_training_jobs``.

    Raises:
        HTTPException: Re-raised unchanged when raised by the service layer;
            status 500 for any other failure.
    """
    try:
        logger.debug("Getting training jobs",
                     tenant_id=tenant_id,
                     status=status,
                     limit=limit,
                     offset=offset)

        jobs = await training_service.get_training_jobs(
            tenant_id=tenant_id,
            status=status,
            limit=limit,
            offset=offset
        )

        logger.debug("Retrieved training jobs",
                     count=len(jobs),
                     tenant_id=tenant_id)
        return jobs
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to get training jobs",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
2025-07-20 07:24:04 +02:00
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get specific training job details.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service used to look the job up.

    Returns:
        The job object returned by ``training_service.get_training_job``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 for any other failure.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)

        job = await training_service.get_training_job(job_id)

        # A missing job must yield 404 rather than an AttributeError (500)
        # from the tenant comparison below.
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")

        # Verify tenant access. A foreign job is reported as 404 as well, so
        # the response does not reveal that the id exists.
        if job.tenant_id != tenant_id:
            logger.warning("Unauthorized job access attempt",
                           job_id=job_id,
                           tenant_id=tenant_id,
                           job_tenant_id=job.tenant_id)
            raise HTTPException(status_code=404, detail="Job not found")

        return job
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get real-time training progress.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service used to look up the job and its progress.

    Returns:
        The value returned by ``training_service.get_job_progress``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 for any other failure.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job exists and belongs to tenant before exposing progress.
        # The None check keeps a missing job from failing with an
        # AttributeError (500) on the tenant comparison.
        job = await training_service.get_training_job(job_id)
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        progress = await training_service.get_job_progress(job_id)
        return progress
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}")
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Cancel a running training job.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; ``user_id`` is logged.
        training_service: Service used to look up and cancel the job.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 for any other failure.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify the job exists and the tenant owns it before cancelling.
        # The None check keeps a missing job from failing with an
        # AttributeError (500) on the tenant comparison.
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event. Cancellation is reported through the
        # "job failed" event; a messaging failure is only logged so the
        # already-completed cancellation is still reported as successful.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
2025-07-19 16:59:37 +02:00
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Train model for a single product.

    Creates a single-product job through the training service, publishes a
    "product training started" event on a best-effort basis, and schedules
    the training on the request's background task queue.

    Args:
        product_name: Product to train, taken from the URL path.
        request: Training configuration; its ``dict()`` form is used as the
            job config.
        background_tasks: Queue on which ``execute_single_product_training``
            is scheduled.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; ``user_id`` is read from it.
        training_service: Service that creates and executes the job.

    Returns:
        The job object returned by ``create_single_product_job``.

    Raises:
        HTTPException: Re-raised unchanged when raised by the service layer;
            status 500 for any other failure.
    """
    try:
        logger.info("Training single product",
                    product_name=product_name,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Create training job for single product
        job = await training_service.create_single_product_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            product_name=product_name,
            config=request.dict()
        )

        # Publish event. A messaging failure must not abort the request, so
        # it is only logged.
        try:
            await publish_product_training_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                product_name=product_name
            )
        except Exception as e:
            logger.warning("Failed to publish product training event", error=str(e))

        # Start training in background
        background_tasks.add_task(
            training_service.execute_single_product_training,
            job.job_id,
            product_name
        )

        logger.info("Single product training started",
                    job_id=job.job_id,
                    product_name=product_name)
        return job
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to train single product",
                     error=str(e),
                     product_name=product_name,
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Validate data before training.

    Args:
        request: Validation request; ``products`` and ``min_data_points`` are
            forwarded to the service.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service that performs the validation.

    Returns:
        The result of ``training_service.validate_training_data``; its
        ``is_valid`` attribute is logged.

    Raises:
        HTTPException: Re-raised unchanged when raised by the service layer;
            status 500 for any other failure.
    """
    try:
        logger.debug("Validating training data",
                     tenant_id=tenant_id,
                     products=request.products)

        validation_result = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points
        )

        logger.debug("Data validation completed",
                     is_valid=validation_result.is_valid,
                     tenant_id=tenant_id)
        return validation_result
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to validate training data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
2025-07-19 16:59:37 +02:00
2025-07-20 07:24:04 +02:00
@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get list of trained models.

    Args:
        product_name: Optional product filter forwarded to the service.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service queried for the models.

    Returns:
        The collection returned by ``training_service.get_trained_models``.

    Raises:
        HTTPException: Re-raised unchanged when raised by the service layer;
            status 500 for any other failure.
    """
    try:
        logger.debug("Getting trained models",
                     tenant_id=tenant_id,
                     product_name=product_name)

        models = await training_service.get_trained_models(
            tenant_id=tenant_id,
            product_name=product_name
        )

        logger.debug("Retrieved trained models",
                     count=len(models),
                     tenant_id=tenant_id)
        return models
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to get trained models",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}")
2025-07-19 16:59:37 +02:00
2025-07-20 07:24:04 +02:00
@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Delete a trained model (admin only).

    Args:
        model_id: Identifier of the model, taken from the URL path.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; ``user_id`` is logged as
            the acting admin.
        training_service: Service used to look up and delete the model.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model is missing, belongs to another
            tenant, or the service reports that nothing was deleted; 500 for
            any other failure.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model exists and belongs to tenant. The None check keeps an
        # unknown id from failing with an AttributeError (500) on the tenant
        # comparison.
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)

        return {"message": "Model deleted successfully", "model_id": model_id}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
2025-07-20 07:24:04 +02:00
@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Get training statistics for tenant.

    Args:
        start_date: Optional lower bound forwarded to the service.
        end_date: Optional upper bound forwarded to the service.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user payload; not read by this handler,
            but declaring the dependency requires an authenticated caller.
        training_service: Service queried for the statistics.

    Returns:
        The value returned by ``training_service.get_training_stats``.

    Raises:
        HTTPException: Re-raised unchanged when raised by the service layer;
            status 500 for any other failure.
    """
    try:
        logger.debug("Getting training stats",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date)

        stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)
        return stats
    except HTTPException:
        # Keep the original status code instead of rewrapping it as a 500,
        # matching the other handlers in this module.
        raise
    except Exception as e:
        logger.error("Failed to get training stats",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")
2025-07-19 16:59:37 +02:00
2025-07-20 07:24:04 +02:00
@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends()
):
    """Retrain all products with existing models.

    Rejects the request with 400 when the tenant has no trained models yet.
    Otherwise creates a training job whose config is the request's ``dict()``
    plus ``is_retrain=True``, publishes a "job started" event on a
    best-effort basis, and schedules ``execute_training_job`` on the
    background task queue.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 400 when no models exist; 500 for any other failure.
    """
    try:
        logger.info(
            "Retraining all products",
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
        )

        # Retraining only makes sense once an initial training has produced models.
        current_models = await training_service.get_trained_models(tenant_id)
        if not current_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first.",
            )

        # The config is rebuilt for each consumer, so the service and the
        # event publisher each receive their own dict.
        retrain_job = await training_service.create_training_job(
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
            config=dict(request.dict(), is_retrain=True),
        )

        # Event publication is best-effort: a failure is logged, not raised.
        try:
            await publish_job_started(
                job_id=retrain_job.job_id,
                tenant_id=tenant_id,
                config=dict(request.dict(), is_retrain=True),
            )
        except Exception as publish_error:
            logger.warning("Failed to publish retrain event", error=str(publish_error))

        # Hand the actual retraining over to the background task queue.
        background_tasks.add_task(training_service.execute_training_job, retrain_job.job_id)

        logger.info("Retraining job created", job_id=retrain_job.job_id)
        return retrain_job

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(
            "Failed to start retraining",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to start retraining: {exc!s}")