Improve training code
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
Models API endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
import structlog
|
||||
@@ -10,6 +10,7 @@ import structlog
|
||||
from app.core.database import get_db
|
||||
from app.schemas.training import TrainedModelResponse
|
||||
from app.services.training_service import TrainingService
|
||||
from datetime import datetime
|
||||
|
||||
from shared.auth.decorators import (
|
||||
get_current_tenant_id_dep
|
||||
@@ -20,17 +21,73 @@ router = APIRouter()
|
||||
|
||||
training_service = TrainingService()
|
||||
|
||||
@router.get("/tenants/{tenant_id}/", response_model=List[TrainedModelResponse])
|
||||
async def get_trained_models(
|
||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
|
||||
async def get_active_model(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
product_name: str = Path(..., description="Product name"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get trained models"""
|
||||
"""
|
||||
Get the active model for a product - used by forecasting service
|
||||
"""
|
||||
try:
|
||||
return await training_service.get_trained_models(tenant_id, db)
|
||||
query = """
|
||||
SELECT * FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND product_name = :product_name
|
||||
AND is_active = true
|
||||
AND is_production = true
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
result = await db.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
})
|
||||
|
||||
model_record = result.fetchone()
|
||||
|
||||
if not model_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No active model found for product {product_name}"
|
||||
)
|
||||
|
||||
# Update last_used_at
|
||||
update_query = """
|
||||
UPDATE trained_models
|
||||
SET last_used_at = :now
|
||||
WHERE id = :model_id
|
||||
"""
|
||||
|
||||
await db.execute(update_query, {
|
||||
"now": datetime.utcnow(),
|
||||
"model_id": model_record.id
|
||||
})
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"model_id": model_record.id,
|
||||
"model_path": model_record.model_path,
|
||||
"features_used": model_record.features_used,
|
||||
"hyperparameters": model_record.hyperparameters,
|
||||
"training_metrics": {
|
||||
"mape": model_record.mape,
|
||||
"mae": model_record.mae,
|
||||
"rmse": model_record.rmse,
|
||||
"r2_score": model_record.r2_score
|
||||
},
|
||||
"created_at": model_record.created_at.isoformat(),
|
||||
"training_period": {
|
||||
"start_date": model_record.training_start_date.isoformat(),
|
||||
"end_date": model_record.training_end_date.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get trained models error: {e}")
|
||||
logger.error(f"Failed to get active model: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get trained models"
|
||||
detail="Failed to retrieve model"
|
||||
)
|
||||
Reference in New Issue
Block a user