Improve training code

This commit is contained in:
Urtzi Alfaro
2025-07-28 19:28:39 +02:00
parent 946015b80c
commit 98f546af12
15 changed files with 2534 additions and 2812 deletions

View File

@@ -2,7 +2,7 @@
Models API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
import structlog
@@ -10,6 +10,7 @@ import structlog
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
from datetime import datetime
from shared.auth.decorators import (
get_current_tenant_id_dep
@@ -20,17 +21,73 @@ router = APIRouter()
training_service = TrainingService()
@router.get("/tenants/{tenant_id}/", response_model=List[TrainedModelResponse])
async def get_trained_models(
tenant_id: str = Depends(get_current_tenant_id_dep),
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: str = Path(..., description="Product name"),
db: AsyncSession = Depends(get_db)
):
"""Get trained models"""
"""
Get the active model for a product - used by forecasting service
"""
try:
return await training_service.get_trained_models(tenant_id, db)
query = """
SELECT * FROM trained_models
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND is_active = true
AND is_production = true
ORDER BY created_at DESC
LIMIT 1
"""
result = await db.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
model_record = result.fetchone()
if not model_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active model found for product {product_name}"
)
# Update last_used_at
update_query = """
UPDATE trained_models
SET last_used_at = :now
WHERE id = :model_id
"""
await db.execute(update_query, {
"now": datetime.utcnow(),
"model_id": model_record.id
})
await db.commit()
return {
"model_id": model_record.id,
"model_path": model_record.model_path,
"features_used": model_record.features_used,
"hyperparameters": model_record.hyperparameters,
"training_metrics": {
"mape": model_record.mape,
"mae": model_record.mae,
"rmse": model_record.rmse,
"r2_score": model_record.r2_score
},
"created_at": model_record.created_at.isoformat(),
"training_period": {
"start_date": model_record.training_start_date.isoformat(),
"end_date": model_record.training_end_date.isoformat()
}
}
except Exception as e:
logger.error(f"Get trained models error: {e}")
logger.error(f"Failed to get active model: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get trained models"
detail="Failed to retrieve model"
)