Files
bakery-ia/services/training/app/api/training.py
2025-07-25 19:40:49 +02:00

489 lines
17 KiB
Python

# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
from typing import List, Optional, Dict, Any
from datetime import datetime
import structlog
import uuid
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatus,
SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
)
from app.services.training_service import TrainingService
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db_session
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
# Module-level structlog logger used by every endpoint in this file.
logger = structlog.get_logger()
# Router for the training endpoints; every route below is mounted under "/training".
router = APIRouter(prefix="/training", tags=["training"])
def get_training_service() -> TrainingService:
    """Dependency provider: construct and return a new TrainingService."""
    service = TrainingService()
    return service
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Start a new training job for all products.

    Generates a job id, creates the job through the training service,
    publishes a best-effort "job started" event and schedules the training
    itself as a background task.

    Args:
        request: Training configuration sent by the caller.
        background_tasks: FastAPI background task queue for this request.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; not read in
            the body.
        training_service: Service that creates and executes the job.
        db: Database session handed to the service calls.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 500 when job creation or scheduling fails.
    """
    try:
        new_job_id = str(uuid.uuid4())
        # Serialise the request once; the log entry, the stored job and the
        # published event all describe the same configuration.
        config = request.dict()

        # Log the id that is actually assigned to the job (previously a second,
        # unrelated UUID was generated here, so logs could not be correlated).
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    job_id=new_job_id,
                    config=config)

        job = await training_service.create_training_job(
            db,
            tenant_id=tenant_id,
            job_id=new_job_id,
            config=config
        )

        # Event publishing is best-effort: a messaging failure must not
        # prevent the job from being started.
        try:
            await publish_job_started(
                job_id=new_job_id,
                tenant_id=tenant_id,
                config=config
            )
        except Exception as e:
            logger.warning("Failed to publish job started event", error=str(e))

        # NOTE(review): the request-scoped session is reused by the background
        # task; confirm get_db_session keeps it open until the task finishes.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job.job_id,
            job.tenant_id,
            request
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)
        return job
    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start training job: {str(e)}"
        ) from e
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """List the tenant's training jobs, optionally filtered by status.

    ``limit`` (1-1000, default 100) and ``offset`` (>= 0, default 0) page
    through the result. Any failure is logged and answered with HTTP 500.
    """
    # One filter set drives both the debug log entry and the service query.
    query_filters = {
        "tenant_id": tenant_id,
        "status": status,
        "limit": limit,
        "offset": offset,
    }
    try:
        logger.debug("Getting training jobs", **query_filters)
        tenant_jobs = await training_service.get_training_jobs(**query_filters)
        logger.debug("Retrieved training jobs",
                     count=len(tenant_jobs),
                     tenant_id=tenant_id)
    except Exception as exc:
        logger.error("Failed to get training jobs",
                     error=str(exc),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(exc)}")
    return tenant_jobs
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return the details of one training job owned by the calling tenant.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; not read in
            the body.
        training_service: Service used to load the job.

    Returns:
        The job object returned by the training service.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)
        job = await training_service.get_training_job(job_id)

        # An unknown id must answer 404; dereferencing a missing job would
        # otherwise surface as an AttributeError and a misleading 500.
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")

        # A job owned by another tenant is reported exactly like a missing
        # one, so the response does not reveal that the id exists.
        if job.tenant_id != tenant_id:
            logger.warning("Unauthorized job access attempt",
                           job_id=job_id,
                           tenant_id=tenant_id,
                           job_tenant_id=job.tenant_id)
            raise HTTPException(status_code=404, detail="Job not found")
        return job
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training job: {str(e)}"
        ) from e
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return the current progress of a training job owned by the tenant.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; not read in
            the body.
        training_service: Service used to load the job and its progress.

    Returns:
        The progress object returned by ``get_job_progress``.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Ownership is checked on the job itself before progress is fetched.
        # A missing job answers 404 instead of failing on ``job.tenant_id``.
        job = await training_service.get_training_job(job_id)
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        progress = await training_service.get_job_progress(job_id)
        return progress
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get training progress: {str(e)}"
        ) from e
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Cancel a training job owned by the calling tenant.

    Args:
        job_id: Identifier of the job, taken from the URL path.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; its
            ``"user_id"`` entry is written to the log.
        training_service: Service used to load and cancel the job.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])
        job = await training_service.get_training_job(job_id)

        # A missing job answers 404 instead of failing on ``job.tenant_id``;
        # a job of another tenant is reported the same way.
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Cancellation is announced through the "job failed" event. Publishing
        # is best-effort and must not undo a successful cancellation.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to cancel training job: {str(e)}"
        ) from e
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Create and schedule a training job for one product.

    The job is created through the training service, a best-effort
    "product training started" event is published, and the training is
    queued as a background task. Any failure is answered with HTTP 500.
    """
    try:
        logger.info(
            "Training single product",
            product_name=product_name,
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
        )

        # Register the job for this single product.
        job_config = request.dict()
        product_job = await training_service.create_single_product_job(
            db,
            tenant_id=tenant_id,
            product_name=product_name,
            config=job_config,
        )

        # Announce the job; a messaging failure is only logged.
        try:
            await publish_product_training_started(
                job_id=product_job.job_id,
                tenant_id=tenant_id,
                product_name=product_name,
            )
        except Exception as publish_error:
            logger.warning("Failed to publish product training event", error=str(publish_error))

        # Queue the training to run as a background task.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            product_job.job_id,
            product_name,
        )

        logger.info(
            "Single product training started",
            job_id=product_job.job_id,
            product_name=product_name,
        )
    except Exception as exc:
        logger.error(
            "Failed to train single product",
            error=str(exc),
            product_name=product_name,
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to train product: {str(exc)}")
    return product_job
@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Check the tenant's data for the requested products before training.

    Delegates to ``TrainingService.validate_training_data`` and returns its
    result. Any failure is logged and answered with HTTP 500.
    """
    try:
        requested_products = request.products
        logger.debug(
            "Validating training data",
            tenant_id=tenant_id,
            products=requested_products,
        )

        outcome = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=requested_products,
            min_data_points=request.min_data_points,
        )

        logger.debug(
            "Data validation completed",
            is_valid=outcome.is_valid,
            tenant_id=tenant_id,
        )
    except Exception as exc:
        logger.error(
            "Failed to validate training data",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(exc)}")
    return outcome
@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """List the tenant's trained models, optionally for a single product.

    Any failure is logged and answered with HTTP 500.
    """
    # The same scope is written to the debug log and sent to the service.
    model_scope = {"tenant_id": tenant_id, "product_name": product_name}
    try:
        logger.debug("Getting trained models", **model_scope)
        trained = await training_service.get_trained_models(**model_scope)
        logger.debug(
            "Retrieved trained models",
            count=len(trained),
            tenant_id=tenant_id,
        )
    except Exception as exc:
        logger.error(
            "Failed to get trained models",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to get models: {str(exc)}")
    return trained
@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Delete a trained model owned by the calling tenant (admin only).

    Args:
        model_id: Identifier of the model, taken from the URL path.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; its
            ``"user_id"`` entry is logged as the acting admin.
        training_service: Service used to load and delete the model.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model does not exist, belongs to another
            tenant, or the service reports that nothing was deleted; 500 on
            any other failure.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # A missing model answers 404 instead of failing on
        # ``model.tenant_id``; a model of another tenant is reported the same
        # way so its existence is not revealed.
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        # A falsy result from the service is also answered with 404.
        success = await training_service.delete_model(model_id)
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)
        return {"message": "Model deleted successfully", "model_id": model_id}
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete model: {str(e)}"
        ) from e
@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return training statistics for the tenant over an optional date range.

    Both bounds are optional and forwarded to the service as given. Any
    failure is logged and answered with HTTP 500.
    """
    # The same window is written to the debug log and sent to the service.
    stats_window = {
        "tenant_id": tenant_id,
        "start_date": start_date,
        "end_date": end_date,
    }
    try:
        logger.debug("Getting training stats", **stats_window)
        tenant_stats = await training_service.get_training_stats(**stats_window)
        logger.debug("Training stats retrieved", tenant_id=tenant_id)
    except Exception as exc:
        logger.error(
            "Failed to get training stats",
            error=str(exc),
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(exc)}")
    return tenant_stats
@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Retrain all products for a tenant that already has trained models.

    Refuses with HTTP 400 when the tenant has no models yet. Otherwise a job
    is created with ``"is_retrain": True`` added to the configuration, a
    best-effort "job started" event is published and the training is
    scheduled as a background task.

    Args:
        request: Training configuration sent by the caller.
        background_tasks: FastAPI background task queue for this request.
        tenant_id: Tenant resolved by ``get_current_tenant_id_dep``.
        current_user: User resolved by ``get_current_user_dep``; its
            ``"user_id"`` entry is written to the log.
        training_service: Service that creates and executes the job.
        db: Database session handed to the service calls.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 400 when no trained models exist for the tenant;
            500 on any other failure.
    """
    try:
        logger.info("Retraining all products",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Retraining only makes sense once an initial training has run.
        existing_models = await training_service.get_trained_models(
            tenant_id=tenant_id
        )
        if not existing_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first."
            )

        new_job_id = str(uuid.uuid4())
        # Built once so the stored job and the published event agree.
        retrain_config = {**request.dict(), "is_retrain": True}

        # The service calls use the same argument shape as the
        # start_training_job endpoint (session first, explicit job id).
        # NOTE(review): the service signature is not visible from this file;
        # confirm it matches, and that dropping the former ``user_id``
        # keyword is intended.
        job = await training_service.create_training_job(
            db,
            tenant_id=tenant_id,
            job_id=new_job_id,
            config=retrain_config
        )

        # Event publishing is best-effort: a messaging failure must not
        # prevent the retraining from being started.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config=retrain_config
            )
        except Exception as e:
            logger.warning("Failed to publish retrain event", error=str(e))

        # NOTE(review): the request-scoped session is reused by the background
        # task; confirm get_db_session keeps it open until the task finishes.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job.job_id,
            job.tenant_id,
            request
        )

        logger.info("Retraining job created", job_id=job.job_id)
        return job
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to start retraining",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start retraining: {str(e)}"
        ) from e