Files
bakery-ia/services/training/app/api/models.py

215 lines
7.9 KiB
Python
Raw Normal View History

"""
Models API endpoints
"""
2025-07-29 19:11:36 +02:00
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from sqlalchemy.ext.asyncio import AsyncSession
2025-07-29 19:11:36 +02:00
from typing import List, Optional
2025-07-18 14:41:39 +02:00
import structlog
2025-07-29 18:37:23 +02:00
from sqlalchemy import text
from app.core.database import get_db
2025-07-29 19:11:36 +02:00
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import TrainingService
2025-07-28 19:28:39 +02:00
from datetime import datetime
2025-07-21 20:43:17 +02:00
from shared.auth.decorators import (
get_current_tenant_id_dep
)
2025-07-19 16:59:37 +02:00
2025-07-18 14:41:39 +02:00
# Module-level structlog logger shared by every endpoint in this file.
logger = structlog.get_logger()
# Router that the endpoints below register on; presumably included by the
# application elsewhere — verify against the app setup.
router = APIRouter()
# TrainingService instance created at import time.
# NOTE(review): not referenced by any endpoint visible in this file; confirm
# whether other modules import it before removing.
training_service = TrainingService()
2025-07-28 19:28:39 +02:00
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by forecasting service.

    Selects the most recently created ``trained_models`` row for the tenant
    and product that is flagged both ``is_active`` and ``is_production``,
    stamps its ``last_used_at`` with the current time, commits, and returns
    the model's location, configuration, training metrics and training period.

    Args:
        tenant_id: Tenant that owns the model.
        product_name: Product the model was trained for.
        db: Async database session.

    Returns:
        A dict with ``model_id``, ``model_path``, ``features_used``,
        ``hyperparameters``, ``training_metrics``, ``created_at`` and
        ``training_period`` (dates rendered as ISO strings or ``None``).

    Raises:
        HTTPException: 404 if no active production model exists for the
            product; 500 for any other failure while querying or updating.
    """
    try:
        # Raw SQL must be wrapped in text() under SQLAlchemy 2.0.
        query = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
            AND product_name = :product_name
            AND is_active = true
            AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)
        result = await db.execute(query, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })
        model_record = result.fetchone()
        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {product_name}"
            )

        # Record that the model was just served.
        update_query = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)
        await db.execute(update_query, {
            # Naive UTC timestamp, as produced by datetime.utcnow().
            "now": datetime.utcnow(),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": model_record.id,
            "model_path": model_record.model_path,
            "features_used": model_record.features_used,
            "hyperparameters": model_record.hyperparameters,
            "training_metrics": {
                "mape": model_record.mape,
                "mae": model_record.mae,
                "rmse": model_record.rmse,
                "r2_score": model_record.r2_score
            },
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "training_period": {
                "start_date": model_record.training_start_date.isoformat() if model_record.training_start_date else None,
                "end_date": model_record.training_end_date.isoformat() if model_record.training_end_date else None
            }
        }
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 404 above) unchanged
        # instead of letting the generic handler below turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get active model: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model"
        ) from e
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
async def get_model_metrics(
    model_id: str = Path(..., description="Model ID"),
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID"),
):
    """
    Get performance metrics for a specific model - used by forecasting service.

    The lookup is scoped to the tenant in the URL, so a tenant cannot read
    metrics for a model that belongs to another tenant.

    Args:
        model_id: Primary key of the trained model.
        db: Async database session.
        tenant_id: Tenant from the route path; the model must belong to it.

    Returns:
        A dict of metrics shaped for ``ModelMetricsResponse``. Missing numeric
        metrics default to 0.0, ``training_samples`` to 0 and
        ``features_used`` to an empty list.

    Raises:
        HTTPException: 404 if the model does not exist for this tenant;
            500 for any other failure.
    """
    try:
        # Filter on tenant_id as well as id: the route is tenant-scoped, and
        # looking up by id alone would expose other tenants' models.
        query = text("""
            SELECT * FROM trained_models
            WHERE id = :model_id
            AND tenant_id = :tenant_id
        """)
        result = await db.execute(query, {"model_id": model_id, "tenant_id": tenant_id})
        model_record = result.fetchone()
        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model {model_id} not found"
            )

        # Return metrics in the format expected by forecasting service.
        metrics = {
            "model_id": model_record.id,
            "accuracy": model_record.r2_score or 0.0,  # R2 is reported as the accuracy measure
            "mape": model_record.mape or 0.0,
            "mae": model_record.mae or 0.0,
            "rmse": model_record.rmse or 0.0,
            "r2_score": model_record.r2_score or 0.0,
            "training_samples": model_record.training_samples or 0,
            "features_used": model_record.features_used or [],
            "model_type": model_record.model_type,
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None
        }
        logger.info(f"Retrieved metrics for model {model_id}",
                    mape=metrics["mape"],
                    accuracy=metrics["accuracy"])
        return metrics
    except HTTPException:
        # Keep deliberate HTTP errors (the 404 above) intact.
        raise
    except Exception as e:
        logger.error(f"Failed to get model metrics: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model metrics"
        ) from e
@router.get("/tenants/{tenant_id}/models", response_model=List[TrainedModelResponse])
async def list_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    status: Optional[str] = Query(None, description="Filter by status (active/inactive)"),
    model_type: Optional[str] = Query(None, description="Filter by model type"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"),
    db: AsyncSession = Depends(get_db)
):
    """
    List models for a tenant - used by forecasting service for model discovery.

    Args:
        tenant_id: Tenant whose models are listed.
        status: Optional status filter. "active" (or "deployed") keeps rows
            that are both ``is_active`` and ``is_production``; "inactive"
            keeps rows where either flag is false; any other value applies
            no status filter.
        model_type: Optional exact-match filter on ``model_type``.
        limit: Maximum number of rows returned, newest first.
        db: Async database session.

    Returns:
        A list of dicts shaped for ``TrainedModelResponse``.

    Raises:
        HTTPException: 500 if the query fails.
    """
    # Inside this function the ``status`` query parameter shadows the
    # module-level ``fastapi.status`` import, so the HTTP status constants
    # are bound under a different local name.
    from fastapi import status as http_status

    try:
        # Only fixed SQL fragments are concatenated; caller-supplied values
        # are always passed as bound parameters.
        query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"]
        params = {"tenant_id": tenant_id}

        if status in ("deployed", "active"):
            query_parts.append("AND is_active = true AND is_production = true")
        elif status == "inactive":
            query_parts.append("AND (is_active = false OR is_production = false)")

        if model_type:
            query_parts.append("AND model_type = :model_type")
            params["model_type"] = model_type

        query_parts.append("ORDER BY created_at DESC LIMIT :limit")
        params["limit"] = limit

        query = text(" ".join(query_parts))
        result = await db.execute(query, params)
        model_records = result.fetchall()

        models = [
            {
                "model_id": record.id,
                "tenant_id": record.tenant_id,
                "product_name": record.product_name,
                "model_type": record.model_type,
                "model_path": record.model_path,
                "version": 1,  # No version column is read; every model reports 1
                "training_samples": record.training_samples or 0,
                "features": record.features_used or [],
                "hyperparameters": record.hyperparameters or {},
                "training_metrics": {
                    "mape": record.mape or 0.0,
                    "mae": record.mae or 0.0,
                    "rmse": record.rmse or 0.0,
                    "r2_score": record.r2_score or 0.0
                },
                "is_active": record.is_active,
                "created_at": record.created_at,
                "data_period_start": record.training_start_date,
                "data_period_end": record.training_end_date
            }
            for record in model_records
        ]

        logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}")
        return models
    except Exception as e:
        logger.error(f"Failed to list models: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve models"
        ) from e