93 lines
2.8 KiB
Python
93 lines
2.8 KiB
Python
"""
Models API endpoints
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Path
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from typing import List
|
|
import structlog
|
|
|
|
from app.core.database import get_db
|
|
from app.schemas.training import TrainedModelResponse
|
|
from app.services.training_service import TrainingService
|
|
from datetime import datetime
|
|
|
|
from shared.auth.decorators import (
|
|
get_current_tenant_id_dep
|
|
)
|
|
|
|
# Structured logger; the endpoint below uses it to report failures.
logger = structlog.get_logger()

# Router on which this module's endpoint decorators register their routes.
router = APIRouter()

# Module-level TrainingService instance, constructed once at import time.
# It is not referenced by the code visible in this part of the file.
training_service = TrainingService()
|
|
|
|
def _isoformat_or_none(value):
    """Return ``value.isoformat()``, or ``None`` when ``value`` is ``None``."""
    return None if value is None else value.isoformat()


@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by forecasting service.

    Selects the most recently created ``trained_models`` row for the given
    tenant and product that has both ``is_active`` and ``is_production``
    set, stamps its ``last_used_at`` column with the current UTC time,
    commits, and returns the row's metadata.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        product_name: Product name taken from the URL path.
        db: Async database session supplied by ``get_db``.

    Returns:
        A dict with ``model_id``, ``model_path``, ``features_used``,
        ``hyperparameters``, ``training_metrics`` (mape/mae/rmse/r2_score),
        ``created_at`` and ``training_period`` (start_date/end_date).
        Date fields are ISO-8601 strings, or ``None`` if the column is NULL.

    Raises:
        HTTPException: 404 when no matching model exists; 500 when the
            lookup or the ``last_used_at`` update fails for any other reason.
    """
    # Function-scope import so this block does not depend on edits to the
    # module's import section. Plain SQL strings passed to Session.execute()
    # are deprecated in SQLAlchemy 1.4 and rejected in 2.0; text() works in
    # both.
    from sqlalchemy import text

    # NOTE(review): tenant_id is trusted as given in the URL. The module
    # imports get_current_tenant_id_dep but does not apply it here - confirm
    # that skipping the tenant check is intentional for this endpoint.
    try:
        select_active = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
            AND product_name = :product_name
            AND is_active = true
            AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)

        result = await db.execute(select_active, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })

        model_record = result.fetchone()

        if model_record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {product_name}"
            )

        # Update last_used_at
        touch_last_used = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)

        await db.execute(touch_last_used, {
            # Kept as a naive UTC timestamp to match what was written before;
            # NOTE(review): switch to an aware datetime only if the column is
            # timezone-aware - confirm against the table definition.
            "now": datetime.utcnow(),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": model_record.id,
            "model_path": model_record.model_path,
            "features_used": model_record.features_used,
            "hyperparameters": model_record.hyperparameters,
            "training_metrics": {
                "mape": model_record.mape,
                "mae": model_record.mae,
                "rmse": model_record.rmse,
                "r2_score": model_record.r2_score
            },
            "created_at": _isoformat_or_none(model_record.created_at),
            "training_period": {
                "start_date": _isoformat_or_none(model_record.training_start_date),
                "end_date": _isoformat_or_none(model_record.training_end_date)
            }
        }

    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged instead
        # of letting the generic handler below rewrite them as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get active model: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model"
        ) from e