496 lines
17 KiB
Python
496 lines
17 KiB
Python
# ================================================================
|
|
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
|
|
# ================================================================
|
|
"""Training API endpoints with unified authentication"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query
|
|
from typing import List, Optional, Dict, Any
|
|
from datetime import datetime
|
|
import structlog
|
|
import uuid
|
|
|
|
from app.schemas.training import (
|
|
TrainingJobRequest,
|
|
TrainingJobResponse,
|
|
TrainingStatus,
|
|
SingleProductTrainingRequest,
|
|
TrainingJobProgress,
|
|
DataValidationRequest,
|
|
DataValidationResponse
|
|
)
|
|
from app.services.training_service import TrainingService
|
|
from app.services.messaging import (
|
|
publish_job_started,
|
|
publish_job_completed,
|
|
publish_job_failed,
|
|
publish_job_progress,
|
|
publish_product_training_started,
|
|
publish_product_training_completed
|
|
)
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.core.database import get_db_session
|
|
|
|
# Import unified authentication from shared library
|
|
from shared.auth.decorators import (
|
|
get_current_user_dep,
|
|
get_current_tenant_id_dep,
|
|
require_role
|
|
)
|
|
|
|
logger = structlog.get_logger()
|
|
router = APIRouter(prefix="/training", tags=["training"])
|
|
|
|
def get_training_service() -> TrainingService:
    """Build a new ``TrainingService`` instance.

    Used as the ``Depends(...)`` provider for the endpoints in this router;
    each call constructs a fresh object with no arguments.
    """
    service = TrainingService()
    return service
|
|
|
|
@router.post("/jobs", response_model=TrainingJobResponse)
async def start_training_job(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)  # Ensure db is available
):
    """Start a new training job for all products.

    Generates a job ID, creates the job through the training service,
    publishes a best-effort "job started" event and schedules the actual
    training run as a background task.

    Args:
        request: Training configuration supplied by the caller.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service used to create and execute the job.
        db: Database session passed to the service calls.

    Returns:
        A ``TrainingJobResponse`` with status ``PENDING``.

    Raises:
        HTTPException: 500 when any step other than event publishing fails.
    """
    try:
        new_job_id = str(uuid.uuid4())

        # Log the ID that is actually used for the job so the log line can be
        # correlated with the created record and the published event.
        logger.info("Starting training job",
                    tenant_id=tenant_id,
                    job_id=new_job_id,
                    config=request.dict())

        # Create training job
        job = await training_service.create_training_job(
            db,  # Pass db here
            tenant_id=tenant_id,
            job_id=new_job_id,
            config=request.dict()
        )

        # Publish job started event; a messaging failure must not abort the job.
        try:
            await publish_job_started(
                job_id=new_job_id,
                tenant_id=tenant_id,
                config=request.dict()
            )
        except Exception as e:
            logger.warning("Failed to publish job started event", error=str(e))

        # Start training in background.
        # NOTE(review): the request-scoped session is handed to the background
        # task; confirm it is still usable when the task runs after the
        # response has been sent.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,  # Pass the database session
            job.job_id,
            job.tenant_id,
            request  # Pass the request object
        )

        logger.info("Training job created",
                    job_id=job.job_id,
                    tenant_id=tenant_id)

        return TrainingJobResponse(
            job_id=job.job_id,
            status=TrainingStatus.PENDING,
            message="Training job created successfully",
            tenant_id=tenant_id,
            created_at=job.created_at,
            estimated_duration_minutes=30
        )

    except Exception as e:
        logger.error("Failed to start training job",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
|
|
|
|
@router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
    status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """List the calling tenant's training jobs.

    Args:
        status: Optional status filter forwarded to the service.
        limit: Page size (1-1000, default 100).
        offset: Number of jobs to skip (>= 0).
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that performs the lookup.

    Returns:
        The value returned by ``training_service.get_training_jobs``.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    # The same keyword set is used for the debug log and the service call.
    query = {
        "tenant_id": tenant_id,
        "status": status,
        "limit": limit,
        "offset": offset,
    }
    try:
        logger.debug("Getting training jobs", **query)

        found = await training_service.get_training_jobs(**query)

        logger.debug("Retrieved training jobs", count=len(found), tenant_id=tenant_id)
        return found
    except Exception as exc:
        logger.error("Failed to get training jobs", error=str(exc), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(exc)}")
|
|
|
|
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get specific training job details.

    Args:
        job_id: Identifier of the job to fetch.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that performs the lookup.

    Returns:
        The job object returned by the training service.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training job",
                     job_id=job_id,
                     tenant_id=tenant_id)

        job = await training_service.get_training_job(job_id)

        # A missing job gets the same 404 as a foreign job instead of an
        # AttributeError-driven 500.
        if job is None:
            raise HTTPException(status_code=404, detail="Job not found")

        # Verify tenant access; respond with 404 so the existence of other
        # tenants' jobs is not revealed.
        if job.tenant_id != tenant_id:
            logger.warning("Unauthorized job access attempt",
                           job_id=job_id,
                           tenant_id=tenant_id,
                           job_tenant_id=job.tenant_id)
            raise HTTPException(status_code=404, detail="Job not found")

        return job

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404s above).
        raise
    except Exception as e:
        logger.error("Failed to get training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
|
|
|
|
@router.get("/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Get real-time training progress.

    Args:
        job_id: Identifier of the job whose progress is requested.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that performs the lookups.

    Returns:
        The value returned by ``training_service.get_job_progress``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.debug("Getting training progress",
                     job_id=job_id,
                     tenant_id=tenant_id)

        # Verify job exists and belongs to tenant before exposing progress.
        job = await training_service.get_training_job(job_id)
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        progress = await training_service.get_job_progress(job_id)

        return progress

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404 above).
        raise
    except Exception as e:
        logger.error("Failed to get training progress",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {str(e)}")
|
|
|
|
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Cancel a running training job.

    Args:
        job_id: Identifier of the job to cancel.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user; ``user_id`` is logged.
        training_service: Service that performs the lookup and cancellation.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job is missing or belongs to another
            tenant; 500 on any other failure.
    """
    try:
        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify the job exists and the tenant owns it before cancelling.
        if job is None or job.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event (sent through the "job failed" channel);
        # a messaging failure must not undo the cancellation.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)

        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404 above).
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
|
|
|
|
@router.post("/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
    product_name: str,
    request: SingleProductTrainingRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Kick off training for one product.

    Creates a single-product job, publishes a best-effort "product training
    started" event, then queues the training run as a background task.

    Args:
        product_name: Product to train, taken from the URL path.
        request: Training configuration supplied by the caller.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user; ``user_id`` is logged.
        training_service: Service used to create and execute the job.
        db: Database session passed to job creation.

    Returns:
        The job object returned by ``create_single_product_job``.

    Raises:
        HTTPException: 500 when any step other than event publishing fails.
    """
    try:
        logger.info(
            "Training single product",
            product_name=product_name,
            tenant_id=tenant_id,
            user_id=current_user["user_id"],
        )

        # Register the job for this one product.
        product_job = await training_service.create_single_product_job(
            db, tenant_id=tenant_id, product_name=product_name, config=request.dict()
        )

        # Event publishing is best effort: log and carry on if it fails.
        try:
            await publish_product_training_started(
                job_id=product_job.job_id, tenant_id=tenant_id, product_name=product_name
            )
        except Exception as publish_err:
            logger.warning("Failed to publish product training event", error=str(publish_err))

        # Defer the actual training until after the response is returned.
        background_tasks.add_task(
            training_service.execute_single_product_training,
            product_job.job_id,
            product_name,
        )

        logger.info(
            "Single product training started",
            job_id=product_job.job_id,
            product_name=product_name,
        )
        return product_job
    except Exception as exc:
        logger.error(
            "Failed to train single product",
            error=str(exc),
            product_name=product_name,
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to train product: {str(exc)}")
|
|
|
|
@router.post("/validate", response_model=DataValidationResponse)
async def validate_training_data(
    request: DataValidationRequest,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Check the tenant's data before a training run.

    Args:
        request: Carries ``products`` and ``min_data_points`` for the check.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that performs the validation.

    Returns:
        The result object from ``training_service.validate_training_data``.

    Raises:
        HTTPException: 500 when validation fails to run.
    """
    try:
        logger.debug("Validating training data", tenant_id=tenant_id, products=request.products)

        outcome = await training_service.validate_training_data(
            tenant_id=tenant_id,
            products=request.products,
            min_data_points=request.min_data_points,
        )

        logger.debug("Data validation completed", is_valid=outcome.is_valid, tenant_id=tenant_id)
        return outcome
    except Exception as exc:
        logger.error("Failed to validate training data", error=str(exc), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(exc)}")
|
|
|
|
@router.get("/models")
async def get_trained_models(
    product_name: Optional[str] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """List the calling tenant's trained models.

    Args:
        product_name: Optional product filter forwarded to the service.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that performs the lookup.

    Returns:
        The value returned by ``training_service.get_trained_models``.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    # The same keyword set is used for the debug log and the service call.
    selector = {"tenant_id": tenant_id, "product_name": product_name}
    try:
        logger.debug("Getting trained models", **selector)

        trained = await training_service.get_trained_models(**selector)

        logger.debug("Retrieved trained models", count=len(trained), tenant_id=tenant_id)
        return trained
    except Exception as exc:
        logger.error("Failed to get trained models", error=str(exc), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get models: {str(exc)}")
|
|
|
|
@router.delete("/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Delete a trained model (admin only).

    Args:
        model_id: Identifier of the model to delete.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user; ``user_id`` is logged as admin_id.
        training_service: Service that performs the lookup and deletion.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model is missing, belongs to another
            tenant, or the service reports the delete as unsuccessful;
            500 on any other failure.
    """
    try:
        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id,
                    admin_id=current_user["user_id"])

        # Verify model exists and belongs to tenant; a missing model gets the
        # same 404 as a foreign one instead of an AttributeError-driven 500.
        model = await training_service.get_model(model_id)
        if model is None or model.tenant_id != tenant_id:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)

        # A falsy result from the service is treated as "nothing was deleted".
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)

        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 404s above).
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
|
|
|
|
@router.get("/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return training statistics for the calling tenant.

    Args:
        start_date: Optional lower bound forwarded to the service.
        end_date: Optional upper bound forwarded to the service.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user (dependency only; not read here).
        training_service: Service that computes the statistics.

    Returns:
        The value returned by ``training_service.get_training_stats``.

    Raises:
        HTTPException: 500 when the statistics cannot be retrieved.
    """
    # The same keyword set is used for the debug log and the service call.
    scope = {"tenant_id": tenant_id, "start_date": start_date, "end_date": end_date}
    try:
        logger.debug("Getting training stats", **scope)

        summary = await training_service.get_training_stats(**scope)

        logger.debug("Training stats retrieved", tenant_id=tenant_id)
        return summary
    except Exception as exc:
        logger.error("Failed to get training stats", error=str(exc), tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(exc)}")
|
|
|
|
@router.post("/retrain/all")
async def retrain_all_products(
    request: TrainingJobRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Depends(get_current_tenant_id_dep),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service),
    db: AsyncSession = Depends(get_db_session)
):
    """Retrain all products with existing models.

    Refuses to run when the tenant has no trained models; otherwise creates a
    job whose config carries ``is_retrain=True``, publishes a best-effort
    "job started" event and schedules the run as a background task.

    Args:
        request: Training configuration supplied by the caller.
        background_tasks: FastAPI background task queue for the training run.
        tenant_id: Tenant resolved by the authentication dependency.
        current_user: Authenticated user; ``user_id`` is logged.
        training_service: Service used to create and execute the job.
        db: Database session passed to the service calls.

    Returns:
        The job object returned by ``create_training_job``.

    Raises:
        HTTPException: 400 when no trained models exist for the tenant;
            500 on any other failure.
    """
    try:
        logger.info("Retraining all products",
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Check if models exist
        existing_models = await training_service.get_trained_models(tenant_id=tenant_id)
        if not existing_models:
            raise HTTPException(
                status_code=400,
                detail="No existing models found. Please run initial training first."
            )

        new_job_id = str(uuid.uuid4())

        # Create retraining job.
        # NOTE(review): call shape mirrors start_training_job (db first,
        # explicit job_id, no user_id); confirm against the TrainingService
        # signature.
        job = await training_service.create_training_job(
            db,
            tenant_id=tenant_id,
            job_id=new_job_id,
            config={**request.dict(), "is_retrain": True}
        )

        # Publish event; a messaging failure must not abort the job.
        try:
            await publish_job_started(
                job_id=job.job_id,
                tenant_id=tenant_id,
                config={**request.dict(), "is_retrain": True}
            )
        except Exception as e:
            logger.warning("Failed to publish retrain event", error=str(e))

        # Start retraining in background with the same argument list that
        # start_training_job passes to execute_training_job.
        # NOTE(review): confirm the request-scoped session is still usable
        # when the background task runs.
        background_tasks.add_task(
            training_service.execute_training_job,
            db,
            job.job_id,
            job.tenant_id,
            request
        )

        logger.info("Retraining job created", job_id=job.job_id)

        return job

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above).
        raise
    except Exception as e:
        logger.error("Failed to start retraining",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}")