"""
Models API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, select, delete, func
from typing import List, Optional
from datetime import datetime, timezone
from pathlib import Path as FilePath  # pathlib.Path; fastapi.Path is reserved for route parameters
import uuid
import shutil
import structlog

from app.core.database import get_db
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import EnhancedTrainingService
from shared.auth.decorators import (
    get_current_user_dep,
    require_admin_role
)
from shared.routing import RouteBuilder
from shared.auth.access_control import (
    require_user_role,
    admin_role_required,
    owner_role_required,
    require_subscription_tier,
    analytics_tier_required,
    enterprise_tier_required
)

# Create route builder for consistent URL structure
route_builder = RouteBuilder('training')
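# NOTE: RouteBuilder's exact path expansion lives in shared.routing. The usage
# sketches below assume build_base_route("models") yields a tenant-scoped path
# such as "/tenants/{tenant_id}/training/models"; that prefix is an assumption
# inferred from the Path(tenant_id) parameters, not something this file defines.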
logger = structlog.get_logger()
router = APIRouter()
training_service = EnhancedTrainingService()

@router.get(
    route_builder.build_base_route("models") + "/{inventory_product_id}/active",
    response_model=TrainedModelResponse
)
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    inventory_product_id: str = Path(..., description="Inventory product UUID"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by the forecasting service.
    """
    try:
        logger.debug("Getting active model", tenant_id=tenant_id, inventory_product_id=inventory_product_id)

        # Raw SQL must be wrapped in text() for SQLAlchemy 2.0
        query = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
              AND inventory_product_id = :inventory_product_id
              AND is_active = true
              AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)
        result = await db.execute(query, {
            "tenant_id": tenant_id,
            "inventory_product_id": inventory_product_id
        })
        model_record = result.fetchone()
        if not model_record:
            logger.info("No active model found", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {inventory_product_id}"
            )

        # Record that the model was served (update queries need text() as well)
        update_query = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)
        await db.execute(update_query, {
            "now": datetime.now(timezone.utc),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": str(model_record.id),
            "tenant_id": str(model_record.tenant_id),
            "inventory_product_id": str(model_record.inventory_product_id),
            "model_type": model_record.model_type,
            "model_path": model_record.model_path,
            "version": 1,  # Default version
            "training_samples": model_record.training_samples or 0,
            "features": model_record.features_used or [],
            "hyperparameters": model_record.hyperparameters or {},
            "training_metrics": {
                "mape": model_record.mape or 0.0,
                "mae": model_record.mae or 0.0,
                "rmse": model_record.rmse or 0.0,
                "r2_score": model_record.r2_score or 0.0
            },
            "is_active": model_record.is_active,
            "created_at": model_record.created_at,
            "data_period_start": model_record.training_start_date,
            "data_period_end": model_record.training_end_date
        }
    except HTTPException:
        raise
    except Exception as e:
        error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
        logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
        # Handle client disconnection gracefully
        if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
            logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, inventory_product_id=inventory_product_id)
            raise HTTPException(
                status_code=status.HTTP_408_REQUEST_TIMEOUT,
                detail="Request connection closed"
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to retrieve model"
            )
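
# Example (sketch): how the forecasting service might fetch the active model.
# TRAINING_BASE_URL, the tenant-scoped prefix and the token handling are
# assumptions; only the ".../models/{inventory_product_id}/active" suffix
# comes from this module.
#
#   import httpx
#   resp = httpx.get(
#       f"{TRAINING_BASE_URL}/tenants/{tenant_id}/training/models/{product_id}/active",
#       headers={"Authorization": f"Bearer {token}"},
#   )
#   resp.raise_for_status()
#   model = resp.json()  # shaped like TrainedModelResponse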

@router.get(
    route_builder.build_nested_resource_route("models", "model_id", "metrics"),
    response_model=ModelMetricsResponse
)
async def get_model_metrics(
    model_id: str = Path(..., description="Model ID"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get performance metrics for a specific model - used by the forecasting service.
    """
    try:
        # Query the model by ID
        query = text("""
            SELECT * FROM trained_models
            WHERE id = :model_id
        """)
        result = await db.execute(query, {"model_id": model_id})
        model_record = result.fetchone()
        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model {model_id} not found"
            )

        # Return metrics in the format expected by the forecasting service
        metrics = {
            "model_id": str(model_record.id),
            "accuracy": model_record.r2_score or 0.0,  # R² doubles as the accuracy measure
            "mape": model_record.mape or 0.0,
            "mae": model_record.mae or 0.0,
            "rmse": model_record.rmse or 0.0,
            "r2_score": model_record.r2_score or 0.0,
            "training_samples": model_record.training_samples or 0,
            "features_used": model_record.features_used or [],
            "model_type": model_record.model_type,
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None
        }
        logger.info(f"Retrieved metrics for model {model_id}",
                    mape=metrics["mape"],
                    accuracy=metrics["accuracy"])
        return metrics
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get model metrics: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model metrics"
        )
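
# Example response (illustrative values only; field names come from the
# handler above, values and the "xgboost" model type are made up):
#
#   {
#       "model_id": "3f2c...", "accuracy": 0.87, "mape": 12.4, "mae": 3.1,
#       "rmse": 4.9, "r2_score": 0.87, "training_samples": 365,
#       "features_used": ["day_of_week", "..."], "model_type": "xgboost",
#       "created_at": "2025-07-29T17:11:36+00:00", "last_used_at": null
#   }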

@router.get(
    route_builder.build_base_route("models"),
    response_model=List[TrainedModelResponse]
)
async def list_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    status_filter: Optional[str] = Query(None, alias="status", description="Filter by status (active/deployed/inactive)"),
    model_type: Optional[str] = Query(None, description="Filter by model type"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"),
    db: AsyncSession = Depends(get_db)
):
    """
    List models for a tenant - used by the forecasting service for model discovery.
    """
    try:
        # Build query with filters.
        # NOTE: the query parameter is exposed as "status" via the alias; the
        # Python name differs so it does not shadow fastapi's status module.
        query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"]
        params = {"tenant_id": tenant_id}
        if status_filter in ("deployed", "active"):
            query_parts.append("AND is_active = true AND is_production = true")
        elif status_filter == "inactive":
            query_parts.append("AND (is_active = false OR is_production = false)")
        if model_type:
            query_parts.append("AND model_type = :model_type")
            params["model_type"] = model_type
        query_parts.append("ORDER BY created_at DESC LIMIT :limit")
        params["limit"] = limit

        query = text(" ".join(query_parts))
        result = await db.execute(query, params)
        model_records = result.fetchall()

        models = []
        for record in model_records:
            models.append({
                "model_id": str(record.id),
                "tenant_id": str(record.tenant_id),
                "inventory_product_id": str(record.inventory_product_id),
                "model_type": record.model_type,
                "model_path": record.model_path,
                "version": 1,  # Default version
                "training_samples": record.training_samples or 0,
                "features": record.features_used or [],
                "hyperparameters": record.hyperparameters or {},
                "training_metrics": {
                    "mape": record.mape or 0.0,
                    "mae": record.mae or 0.0,
                    "rmse": record.rmse or 0.0,
                    "r2_score": record.r2_score or 0.0
                },
                "is_active": record.is_active,
                "created_at": record.created_at,
                "data_period_start": record.training_start_date,
                "data_period_end": record.training_end_date
            })
        logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}")
        return models
    except Exception as e:
        logger.error(f"Failed to list models: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve models"
        )
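
# Example (sketch): discovering a tenant's production models. The mount prefix
# is the same assumption as above; the "status", "model_type" and "limit"
# query parameters come from the handler, the model type value is illustrative.
#
#   resp = httpx.get(
#       f"{TRAINING_BASE_URL}/tenants/{tenant_id}/training/models",
#       params={"status": "active", "model_type": "xgboost", "limit": 10},
#       headers={"Authorization": f"Bearer {token}"},
#   )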

@router.delete("/models/tenant/{tenant_id}")
@require_user_role(['admin', 'owner'])
async def delete_tenant_models_complete(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Delete all trained models and artifacts for a tenant.

    **WARNING: This operation is irreversible!**

    This endpoint:
    1. Cancels any active training jobs for the tenant
    2. Deletes all model artifacts (files) from storage
    3. Deletes model records, training logs and performance metrics from the database
    4. Cleans up the tenant's model storage directory

    Used by the admin user-deletion process to clean up all training data.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import (
            ModelTrainingLog,
            TrainedModel,
            ModelArtifact,
            ModelPerformanceMetric,
            TrainingJobQueue
        )
        from app.core.config import settings

        deletion_stats = {
            "tenant_id": tenant_id,
            "deleted_at": datetime.now(timezone.utc).isoformat(),
            "jobs_cancelled": 0,
            "models_deleted": 0,
            "artifacts_deleted": 0,
            "artifacts_files_deleted": 0,
            "training_logs_deleted": 0,
            "performance_metrics_deleted": 0,
            "storage_freed_bytes": 0,
            "errors": []
        }
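
        # Each step below is best-effort: failures are recorded in
        # deletion_stats["errors"] and processing continues, except for the
        # database deletion step, which rolls back and re-raises on error.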
        # Step 1: Cancel active training jobs
        try:
            active_jobs_query = select(TrainingJobQueue).where(
                TrainingJobQueue.tenant_id == tenant_uuid,
                TrainingJobQueue.status.in_(["queued", "running", "pending"])
            )
            active_jobs_result = await db.execute(active_jobs_query)
            active_jobs = active_jobs_result.scalars().all()
            for job in active_jobs:
                job.status = "cancelled"
                job.updated_at = datetime.now(timezone.utc)
                deletion_stats["jobs_cancelled"] += 1
            if active_jobs:
                await db.commit()
                logger.info("Cancelled active training jobs",
                            tenant_id=tenant_id,
                            count=len(active_jobs))
        except Exception as e:
            error_msg = f"Error cancelling training jobs: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)

        # Step 2: Delete model artifact files from storage
        try:
            artifacts_query = select(ModelArtifact).where(
                ModelArtifact.tenant_id == tenant_uuid
            )
            artifacts_result = await db.execute(artifacts_query)
            artifacts = artifacts_result.scalars().all()
            storage_freed = 0
            files_deleted = 0
            for artifact in artifacts:
                try:
                    # FilePath is pathlib.Path (fastapi.Path would not work here)
                    file_path = FilePath(artifact.file_path)
                    if file_path.exists():
                        file_size = file_path.stat().st_size
                        file_path.unlink()  # Delete file
                        storage_freed += file_size
                        files_deleted += 1
                        logger.debug("Deleted artifact file",
                                     file_path=str(file_path),
                                     size_bytes=file_size)
                        # Also try to delete parent directories if empty
                        try:
                            if file_path.parent.exists() and not any(file_path.parent.iterdir()):
                                file_path.parent.rmdir()
                        except OSError:
                            pass  # Ignore errors cleaning up directories
                except Exception as e:
                    error_msg = f"Error deleting artifact file {artifact.file_path}: {str(e)}"
                    deletion_stats["errors"].append(error_msg)
                    logger.warning(error_msg)
            deletion_stats["artifacts_files_deleted"] = files_deleted
            deletion_stats["storage_freed_bytes"] = storage_freed
            logger.info("Deleted artifact files",
                        tenant_id=tenant_id,
                        files_deleted=files_deleted,
                        storage_freed_mb=storage_freed / (1024 * 1024))
        except Exception as e:
            error_msg = f"Error processing artifact files: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)
        # Step 3: Delete database records
        try:
            # Delete model performance metrics
            metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
                ModelPerformanceMetric.tenant_id == tenant_uuid
            )
            metrics_count_result = await db.execute(metrics_count_query)
            metrics_count = metrics_count_result.scalar()
            metrics_delete_query = delete(ModelPerformanceMetric).where(
                ModelPerformanceMetric.tenant_id == tenant_uuid
            )
            await db.execute(metrics_delete_query)
            deletion_stats["performance_metrics_deleted"] = metrics_count

            # Delete model artifact records
            artifacts_count_query = select(func.count(ModelArtifact.id)).where(
                ModelArtifact.tenant_id == tenant_uuid
            )
            artifacts_count_result = await db.execute(artifacts_count_query)
            artifacts_count = artifacts_count_result.scalar()
            artifacts_delete_query = delete(ModelArtifact).where(
                ModelArtifact.tenant_id == tenant_uuid
            )
            await db.execute(artifacts_delete_query)
            deletion_stats["artifacts_deleted"] = artifacts_count

            # Delete trained models
            models_count_query = select(func.count(TrainedModel.id)).where(
                TrainedModel.tenant_id == tenant_uuid
            )
            models_count_result = await db.execute(models_count_query)
            models_count = models_count_result.scalar()
            models_delete_query = delete(TrainedModel).where(
                TrainedModel.tenant_id == tenant_uuid
            )
            await db.execute(models_delete_query)
            deletion_stats["models_deleted"] = models_count

            # Delete training logs
            logs_count_query = select(func.count(ModelTrainingLog.id)).where(
                ModelTrainingLog.tenant_id == tenant_uuid
            )
            logs_count_result = await db.execute(logs_count_query)
            logs_count = logs_count_result.scalar()
            logs_delete_query = delete(ModelTrainingLog).where(
                ModelTrainingLog.tenant_id == tenant_uuid
            )
            await db.execute(logs_delete_query)
            deletion_stats["training_logs_deleted"] = logs_count

            # Delete job queue entries
            queue_delete_query = delete(TrainingJobQueue).where(
                TrainingJobQueue.tenant_id == tenant_uuid
            )
            await db.execute(queue_delete_query)

            await db.commit()
            logger.info("Deleted training database records",
                        tenant_id=tenant_id,
                        models=models_count,
                        artifacts=artifacts_count,
                        logs=logs_count,
                        metrics=metrics_count)
        except Exception as e:
            await db.rollback()
            error_msg = f"Error deleting database records: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=error_msg
            )
        # Step 4: Clean up the tenant's model directory
        try:
            tenant_model_dir = FilePath(settings.MODEL_STORAGE_PATH) / tenant_id
            if tenant_model_dir.exists():
                shutil.rmtree(tenant_model_dir)
                logger.info("Deleted tenant model directory",
                            directory=str(tenant_model_dir))
        except Exception as e:
            error_msg = f"Error deleting model directory: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.warning(error_msg)

        return {
            "success": True,
            "message": f"All training data for tenant {tenant_id} deleted successfully",
            "deletion_details": deletion_stats
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Unexpected error deleting tenant models",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete tenant models: {str(e)}"
        )
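
# Example (sketch): how the admin user-deletion flow might invoke this
# endpoint. TRAINING_BASE_URL and the token handling are assumptions; the
# "/models/tenant/{tenant_id}" route and the admin/owner role requirement
# come from the handler above.
#
#   resp = httpx.delete(
#       f"{TRAINING_BASE_URL}/models/tenant/{tenant_id}",
#       headers={"Authorization": f"Bearer {admin_token}"},
#   )
#   stats = resp.json()["deletion_details"]  # counts, freed bytes, errors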