Files
bakery-ia/services/training/app/api/models.py
2025-07-29 18:37:23 +02:00

95 lines
3.1 KiB
Python

"""
Models API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
import structlog
from sqlalchemy import text
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
from datetime import datetime
from shared.auth.decorators import (
get_current_tenant_id_dep
)
# Module-level structlog logger used by the endpoints in this module.
logger = structlog.get_logger()
# Router on which the endpoints below register themselves via @router.get(...).
router = APIRouter()
# Service instance created once at import time.
# NOTE(review): not referenced by get_active_model — presumably used by other
# endpoints in this module; confirm before removing.
training_service = TrainingService()
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by forecasting service.

    Looks up the most recently created row in ``trained_models`` for the
    given tenant and product that is flagged both ``is_active`` and
    ``is_production``, stamps its ``last_used_at`` column with the current
    time, commits, and returns a summary of the model.

    Args:
        tenant_id: Tenant identifier taken from the URL path.
        product_name: Product name taken from the URL path.
        db: Async database session provided by the ``get_db`` dependency.

    Returns:
        A dict with ``model_id``, ``model_path``, ``features_used``,
        ``hyperparameters``, a ``training_metrics`` dict (``mape``, ``mae``,
        ``rmse``, ``r2_score``), ``created_at`` and a ``training_period``
        dict (``start_date``, ``end_date``). Date values are ISO-8601
        strings, or ``None`` when the underlying column is empty.

    Raises:
        HTTPException: 404 when no matching model exists; 500 when the
            lookup or the ``last_used_at`` update fails for any other reason.
    """
    try:
        # Raw SQL must be wrapped in text() for SQLAlchemy 2.0.
        query = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
            AND product_name = :product_name
            AND is_active = true
            AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)
        result = await db.execute(query, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })
        model_record = result.fetchone()

        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {product_name}"
            )

        # Record that the model was just requested.
        update_query = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)
        await db.execute(update_query, {
            # Naive UTC timestamp, kept as in the original.
            # NOTE(review): switch to an aware datetime only after confirming
            # the column is timezone-aware.
            "now": datetime.utcnow(),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": model_record.id,
            "model_path": model_record.model_path,
            "features_used": model_record.features_used,
            "hyperparameters": model_record.hyperparameters,
            "training_metrics": {
                "mape": model_record.mape,
                "mae": model_record.mae,
                "rmse": model_record.rmse,
                "r2_score": model_record.r2_score
            },
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "training_period": {
                "start_date": model_record.training_start_date.isoformat() if model_record.training_start_date else None,
                "end_date": model_record.training_end_date.isoformat() if model_record.training_end_date else None
            }
        }
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client as-is;
        # without this clause the generic handler below turns them into 500s.
        raise
    except Exception as e:
        logger.error(f"Failed to get active model: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model"
        ) from e