Improve training code

This commit is contained in:
Urtzi Alfaro
2025-07-28 19:28:39 +02:00
parent 946015b80c
commit 98f546af12
15 changed files with 2534 additions and 2812 deletions

View File

@@ -1,539 +1,209 @@
# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""
# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query, Path
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
from datetime import datetime
import structlog
from uuid import UUID, uuid4
from datetime import datetime
from app.core.database import get_db
from app.services.training_service import TrainingService
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatus,
SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
from app.schemas.training import (
TrainingJobResponse
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db_session
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
# Import shared auth decorators (assuming they exist in your microservices)
from shared.auth.decorators import get_current_tenant_id_dep
logger = structlog.get_logger()
router = APIRouter(tags=["training"])
router = APIRouter()
def get_training_service() -> TrainingService:
    """Build the TrainingService handed to route handlers.

    Route signatures in this module reference this function through
    ``Depends(get_training_service)``. Each call constructs a new
    ``TrainingService`` with no arguments.
    """
    service = TrainingService()
    return service
# Initialize training service
training_service = TrainingService()
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session) # Ensure db is available
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Start a new training job for all products"""
"""
Start a new training job for all tenant products.
This is the main entry point for the training pipeline:
API → Training Service → Trainer → Data Processor → Prophet Manager
"""
try:
tenant_id_str = str(tenant_id)
new_job_id = str(uuid4())
logger.info("Starting training job",
tenant_id=tenant_id_str,
job_id=uuid4(),
config=request.dict())
# Create training job
job = await training_service.create_training_job(
db, # Pass db here
tenant_id=tenant_id_str,
job_id=new_job_id,
config=request.dict()
)
# Publish job started event
try:
await publish_job_started(
job_id=new_job_id,
tenant_id=tenant_id_str,
config=request.dict()
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
except Exception as e:
logger.warning("Failed to publish job started event", error=str(e))
background_tasks.add_task(
training_service.execute_training_job_simple,
new_job_id,
tenant_id_str,
request
)
logger.info("Training job created",
job_id=job.job_id,
tenant_id=tenant_id)
return TrainingJobResponse(
job_id=job.job_id,
status=TrainingStatus.PENDING,
message="Training job created successfully",
tenant_id=tenant_id_str,
created_at=job.created_at,
estimated_duration_minutes=30
)
except Exception as e:
logger.error("Failed to start training job",
error=str(e),
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Get training jobs for tenant"""
try:
tenant_id_str = str(tenant_id)
logger.info(f"Starting training job for tenant {tenant_id}")
logger.debug("Getting training jobs",
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset)
training_service = TrainingService(db_session=db)
jobs = await training_service.get_training_jobs(
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset
# Delegate to training service (Step 1 of the flow)
result = await training_service.start_training_job(
tenant_id=tenant_id,
bakery_location=request.bakery_location or (40.4168, -3.7038), # Default Madrid
requested_start=request.start_date if request.start_date else None,
requested_end=request.end_date if request.end_date else None,
job_id=request.job_id
)
logger.debug("Retrieved training jobs",
count=len(jobs),
tenant_id=tenant_id_str)
return TrainingJobResponse(**result)
return jobs
except Exception as e:
logger.error("Failed to get training jobs",
error=str(e),
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
job_id: str,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session)
):
"""Get specific training job details"""
try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training job",
job_id=job_id,
tenant_id=tenant_id_str)
job_log = await training_service.get_job_status(db, job_id, tenant_id_str)
# Verify tenant access
if job_log.tenant_id != tenant_id:
logger.warning("Unauthorized job access attempt",
job_id=job_id,
tenant_id=str(tenant_id),
job_tenant_id=job.tenant_id)
raise HTTPException(status_code=404, detail="Job not found")
return TrainingJobResponse(
job_id=job_log.job_id,
status=TrainingStatus(job_log.status),
message=_generate_status_message(job_log.status, job_log.current_step),
tenant_id=str(job_log.tenant_id),
created_at=job_log.start_time,
estimated_duration_minutes=_estimate_duration(job_log.status, job_log.progress)
except ValueError as e:
logger.error(f"Training job validation error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except HTTPException:
raise
except Exception as e:
logger.error("Failed to get training job",
error=str(e),
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return the current progress of one training job.

    The job is loaded first so its owner can be checked against the tenant
    in the URL; a job owned by another tenant is reported as 404.

    Raises:
        HTTPException: 404 when the job's tenant differs from the path
            tenant; 500 (with the error text in the detail) for any other
            failure.
    """
    try:
        requested_tenant = str(tenant_id)
        logger.debug(
            "Getting training progress",
            job_id=job_id,
            tenant_id=requested_tenant,
        )

        # Ownership check happens before any progress data is fetched.
        owning_job = await training_service.get_training_job(job_id)
        if owning_job.tenant_id != requested_tenant:
            raise HTTPException(status_code=404, detail="Job not found")

        return await training_service.get_job_progress(job_id)

    except HTTPException:
        # The 404 above must reach the client unchanged.
        raise
    except Exception as exc:
        reason = str(exc)
        logger.error(
            "Failed to get training progress",
            error=reason,
            job_id=job_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {reason}")
@router.post("/tenants/{tenant_id}/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Cancel a running training job.

    Loads the job, checks that it belongs to the tenant in the URL, asks the
    training service to cancel it, and then publishes a job-failed event
    marked as a user cancellation. Publishing is best-effort: a failure
    there is logged as a warning and does not fail the request.

    Returns:
        A dict with a confirmation ``message`` and the ``job_id``.

    Raises:
        HTTPException: 404 when the job is missing or owned by another
            tenant; 500 (with the error text in the detail) for any other
            failure.
    """
    try:
        # The path parameter is parsed into a UUID. Work with its string
        # form everywhere, as the other job endpoints in this module do.
        tenant_id_str = str(tenant_id)

        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id_str,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify tenant access. Both sides are compared as strings so the
        # check works whether the stored tenant id is a str or a UUID; a raw
        # str-vs-UUID comparison is always unequal and would reject every
        # request. NOTE(review): whether get_training_job can return None is
        # not visible here -- the guard only turns that case into a 404.
        if job is None or str(job.tenant_id) != tenant_id_str:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event (best-effort).
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id_str,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        # Let the 404 above through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
logger.error(f"Training job failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Training job failed"
)
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
product_name: str,
async def start_single_product_training(
request: SingleProductTrainingRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session)
tenant_id: str = Path(..., description="Tenant ID"),
product_name: str = Path(..., description="Product name"),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Train model for a single product"""
"""
Start training for a single product.
Uses the same pipeline but filters for specific product.
"""
try:
logger.info("Training single product",
product_name=product_name,
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Create training job for single product
job = await training_service.create_single_product_job(
db,
logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
# Delegate to training service
result = await training_service.start_single_product_training(
tenant_id=tenant_id,
product_name=product_name,
config=request.dict()
sales_data=request.sales_data,
bakery_location=request.bakery_location or (40.4168, -3.7038),
weather_data=request.weather_data,
traffic_data=request.traffic_data,
job_id=request.job_id
)
# Publish event
try:
await publish_product_training_started(
job_id=job.job_id,
tenant_id=tenant_id,
product_name=product_name
)
except Exception as e:
logger.warning("Failed to publish product training event", error=str(e))
return TrainingJobResponse(**result)
# Start training in background
background_tasks.add_task(
training_service.execute_single_product_training,
job.job_id,
product_name
except ValueError as e:
logger.error(f"Single product training validation error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
logger.info("Single product training started",
job_id=job.job_id,
product_name=product_name)
return job
except Exception as e:
logger.error("Failed to train single product",
error=str(e),
product_name=product_name,
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
@router.post("/tenants/{tenant_id}/training/validate", response_model=DataValidationResponse)
async def validate_training_data(
request: DataValidationRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Validate data before training"""
try:
logger.debug("Validating training data",
tenant_id=tenant_id,
products=request.products)
validation_result = await training_service.validate_training_data(
tenant_id=tenant_id,
products=request.products,
min_data_points=request.min_data_points
logger.error(f"Single product training failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Single product training failed"
)
logger.debug("Data validation completed",
is_valid=validation_result.is_valid,
tenant_id=tenant_id)
return validation_result
except Exception as e:
logger.error("Failed to validate training data",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
@router.get("/tenants/{tenant_id}/models")
async def get_trained_models(
product_name: Optional[str] = Query(None),
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
async def cancel_training_job(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Get list of trained models"""
"""
Cancel a running training job.
"""
try:
logger.debug("Getting trained models",
tenant_id=tenant_id,
product_name=product_name)
models = await training_service.get_trained_models(
tenant_id=tenant_id,
product_name=product_name
)
logger.debug("Retrieved trained models",
count=len(models),
tenant_id=tenant_id)
return models
except Exception as e:
logger.error("Failed to get trained models",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}")
@router.delete("/tenants/{tenant_id}/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Delete a trained model (admin only).

    Loads the model, checks that it belongs to the tenant in the URL, and
    asks the training service to delete it.

    Returns:
        A dict with a confirmation ``message`` and the ``model_id``.

    Raises:
        HTTPException: 404 when the model is missing, owned by another
            tenant, or the service reports the delete as unsuccessful;
            500 (with the error text in the detail) for any other failure.
    """
    try:
        # The path parameter is parsed into a UUID; use its string form for
        # logging and for the ownership comparison below.
        tenant_id_str = str(tenant_id)

        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id_str,
                    admin_id=current_user["user_id"])

        # Verify model belongs to tenant. Both sides are compared as strings
        # so the check works whether the stored tenant id is a str or a
        # UUID; a raw str-vs-UUID comparison is always unequal.
        # NOTE(review): whether get_model can return None is not visible
        # here -- the guard only turns that case into a 404.
        model = await training_service.get_model(model_id)
        if model is None or str(model.tenant_id) != tenant_id_str:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)
        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        # Let the 404s above through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
@router.get("/tenants/{tenant_id}/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """Return training statistics for a tenant.

    ``start_date`` and ``end_date`` are optional query parameters and are
    forwarded to the training service as given. Whatever the service
    returns is passed back to the client.

    Raises:
        HTTPException: 500 (with the error text in the detail) when the
            service call fails.
    """
    try:
        logger.debug(
            "Getting training stats",
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
        )

        date_window = {"start_date": start_date, "end_date": end_date}
        tenant_stats = await training_service.get_training_stats(
            tenant_id=tenant_id,
            **date_window,
        )

        logger.debug("Training stats retrieved", tenant_id=tenant_id)
        return tenant_stats

    except Exception as exc:
        reason = str(exc)
        logger.error(
            "Failed to get training stats",
            error=reason,
            tenant_id=tenant_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {reason}")
@router.post("/tenants/{tenant_id}/retrain/all")
async def retrain_all_products(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Retrain all products with existing models"""
try:
logger.info("Retraining all products",
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Check if models exist
existing_models = await training_service.get_trained_models(tenant_id)
if not existing_models:
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=400,
detail="No existing models found. Please run initial training first."
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Create retraining job
job = await training_service.create_training_job(
tenant_id=tenant_id,
user_id=current_user["user_id"],
config={**request.dict(), "is_retrain": True}
)
# TODO: Implement job cancellation
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
# Publish event
try:
await publish_job_started(
job_id=job.job_id,
tenant_id=tenant_id,
config={**request.dict(), "is_retrain": True}
)
except Exception as e:
logger.warning("Failed to publish retrain event", error=str(e))
return {"message": "Training job cancelled successfully"}
# Start retraining in background
background_tasks.add_task(
training_service.execute_training_job,
job.job_id
)
logger.info("Retraining job created", job_id=job.job_id)
return job
except HTTPException:
raise
except Exception as e:
logger.error("Failed to start retraining",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}")
def _generate_status_message(status: str, current_step: str) -> str:
"""Generate appropriate status message"""
status_messages = {
"pending": "Training job is queued",
"running": f"Training in progress: {current_step}",
"completed": "Training completed successfully",
"failed": "Training failed",
"cancelled": "Training was cancelled"
}
return status_messages.get(status, f"Status: {status}")
logger.error(f"Failed to cancel training job: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel training job"
)
def _estimate_duration(status: str, progress: int) -> int:
"""Estimate remaining duration in minutes"""
if status == "completed":
return 0
elif status == "failed" or status == "cancelled":
return 0
elif status == "pending":
return 30 # Default estimate
else: # running
if progress > 0:
# Rough estimate based on progress
remaining_progress = 100 - progress
return max(1, int((remaining_progress / max(progress, 1)) * 10))
else:
return 25 # Default for running jobs
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    limit: int = Query(100, description="Number of log entries to return"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get training job logs.

    Log retrieval is not implemented yet: after the tenant check, a fixed
    placeholder list of log lines is returned. ``limit`` and ``db`` are
    accepted but not used by the placeholder.

    Raises:
        HTTPException: 403 when the tenant in the URL differs from the
            authenticated tenant; 500 for any unexpected failure.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # TODO: Implement log retrieval
        return {
            "job_id": job_id,
            "logs": [
                f"Training job {job_id} started",
                "Data preprocessing completed",
                "Model training completed",
                "Training job finished successfully"
            ]
        }

    except HTTPException:
        # Without this clause the 403 raised above is caught by the generic
        # handler below and reaches the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training logs"
        )
@router.get("/health")
async def health_check():
    """
    Report that the training service is up.

    Returns a dict with a fixed status, service name and version, plus the
    current local time from ``datetime.now()`` as an ISO-8601 string.
    """
    checked_at = datetime.now().isoformat()
    return dict(
        status="healthy",
        service="training",
        version="1.0.0",
        timestamp=checked_at,
    )