REFACTOR - Database logic
This commit is contained in:
@@ -38,11 +38,12 @@ async def get_active_model(
|
||||
Get the active model for a product - used by forecasting service
|
||||
"""
|
||||
try:
|
||||
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0
|
||||
logger.debug("Getting active model", tenant_id=tenant_id, product_name=product_name)
|
||||
# ✅ FIX: Wrap SQL with text() for SQLAlchemy 2.0 and add case-insensitive product name matching
|
||||
query = text("""
|
||||
SELECT * FROM trained_models
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND product_name = :product_name
|
||||
AND LOWER(product_name) = LOWER(:product_name)
|
||||
AND is_active = true
|
||||
AND is_production = true
|
||||
ORDER BY created_at DESC
|
||||
@@ -57,6 +58,7 @@ async def get_active_model(
|
||||
model_record = result.fetchone()
|
||||
|
||||
if not model_record:
|
||||
logger.info("No active model found", tenant_id=tenant_id, product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No active model found for product {product_name}"
|
||||
@@ -76,7 +78,7 @@ async def get_active_model(
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"model_id": model_record.id, # ✅ This is the correct field name
|
||||
"model_id": str(model_record.id), # ✅ This is the correct field name
|
||||
"model_path": model_record.model_path,
|
||||
"features_used": model_record.features_used,
|
||||
"hyperparameters": model_record.hyperparameters,
|
||||
@@ -93,12 +95,24 @@ async def get_active_model(
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active model: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model"
|
||||
)
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
|
||||
logger.error(f"Failed to get active model: {error_msg}", tenant_id=tenant_id, product_name=product_name)
|
||||
|
||||
# Handle client disconnection gracefully
|
||||
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
|
||||
logger.info("Client disconnected during model retrieval", tenant_id=tenant_id, product_name=product_name)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_408_REQUEST_TIMEOUT,
|
||||
detail="Request connection closed"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve model"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
|
||||
async def get_model_metrics(
|
||||
@@ -126,7 +140,7 @@ async def get_model_metrics(
|
||||
|
||||
# Return metrics in the format expected by forecasting service
|
||||
metrics = {
|
||||
"model_id": model_record.id,
|
||||
"model_id": str(model_record.id),
|
||||
"accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure
|
||||
"mape": model_record.mape or 0.0,
|
||||
"mae": model_record.mae or 0.0,
|
||||
@@ -189,8 +203,8 @@ async def list_models(
|
||||
models = []
|
||||
for record in model_records:
|
||||
models.append({
|
||||
"model_id": record.id,
|
||||
"tenant_id": record.tenant_id,
|
||||
"model_id": str(record.id),
|
||||
"tenant_id": str(record.tenant_id),
|
||||
"product_name": record.product_name,
|
||||
"model_type": record.model_type,
|
||||
"model_path": record.model_path,
|
||||
|
||||
Reference in New Issue
Block a user