Files
bakery-ia/services/training/app/api/models.py

95 lines
3.1 KiB
Python
Raw Normal View History

"""
Models API endpoints
"""
2025-07-28 19:28:39 +02:00
from fastapi import APIRouter, Depends, HTTPException, status, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
2025-07-18 14:41:39 +02:00
import structlog
2025-07-29 18:37:23 +02:00
from sqlalchemy import text
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
2025-07-28 19:28:39 +02:00
from datetime import datetime
2025-07-21 20:43:17 +02:00
from shared.auth.decorators import (
get_current_tenant_id_dep
)
2025-07-19 16:59:37 +02:00
2025-07-18 14:41:39 +02:00
# Module-level structlog logger used by the endpoint(s) in this module.
logger = structlog.get_logger()
# Router collecting this module's model endpoints; presumably included into the
# FastAPI app elsewhere (not visible here) — confirm mount prefix there.
router = APIRouter()
# TrainingService instance constructed once at import time.
# NOTE(review): not referenced by any code visible in this module — confirm it
# is still needed (its constructor runs on import).
training_service = TrainingService()
2025-07-28 19:28:39 +02:00
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by forecasting service.

    Selects the most recently created row in ``trained_models`` for the given
    tenant and product that is flagged both ``is_active`` and
    ``is_production``, stamps its ``last_used_at`` with the current UTC time
    (committing the session), and returns the model's metadata.

    Args:
        tenant_id: Tenant whose models are searched (path parameter).
        product_name: Product the model was trained for (path parameter).
        db: Async SQLAlchemy session injected via ``get_db``.

    Returns:
        A dict with ``model_id``, ``model_path``, ``features_used``,
        ``hyperparameters``, ``training_metrics`` (``mape``, ``mae``,
        ``rmse``, ``r2_score``), ``created_at`` and ``training_period``
        (``start_date`` / ``end_date``); date fields are ISO-8601 strings
        or ``None`` when the column is NULL.

    Raises:
        HTTPException: 404 if no active production model exists for the
            tenant/product; 500 for any other failure while querying,
            updating or serializing the record.
    """
    # NOTE(review): get_current_tenant_id_dep is imported in this module but
    # not applied here, so tenant_id comes straight from the path — confirm
    # this endpoint is only reachable service-to-service.
    try:
        # Raw SQL must be wrapped in text() under SQLAlchemy 2.0; values are
        # passed as bound parameters, never interpolated into the string.
        query = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
            AND product_name = :product_name
            AND is_active = true
            AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)
        result = await db.execute(query, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })
        model_record = result.fetchone()
        # fetchone() returns None when the query matched no rows.
        if model_record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {product_name}"
            )

        # Record that this model was just served. datetime.utcnow() yields a
        # naive UTC timestamp, matching what this code has always written.
        update_query = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)
        await db.execute(update_query, {
            "now": datetime.utcnow(),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": model_record.id,
            "model_path": model_record.model_path,
            "features_used": model_record.features_used,
            "hyperparameters": model_record.hyperparameters,
            "training_metrics": {
                "mape": model_record.mape,
                "mae": model_record.mae,
                "rmse": model_record.rmse,
                "r2_score": model_record.r2_score
            },
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "training_period": {
                "start_date": model_record.training_start_date.isoformat() if model_record.training_start_date else None,
                "end_date": model_record.training_end_date.isoformat() if model_record.training_end_date else None
            }
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client
        # unchanged; without this clause the generic handler below would
        # rewrite them as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get active model: {str(e)}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model"
        ) from e