From 3f63cc2a494f5a8673c37b45d4f809a7cb2bafc3 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Tue, 29 Jul 2025 19:11:36 +0200 Subject: [PATCH] Start fixing forecast service API 12 --- .../forecasting/app/services/model_client.py | 48 +++++-- services/training/app/api/models.py | 128 +++++++++++++++++- services/training/app/schemas/training.py | 21 ++- 3 files changed, 178 insertions(+), 19 deletions(-) diff --git a/services/forecasting/app/services/model_client.py b/services/forecasting/app/services/model_client.py index 3cb19922..71808775 100644 --- a/services/forecasting/app/services/model_client.py +++ b/services/forecasting/app/services/model_client.py @@ -73,20 +73,42 @@ class ModelClient: logger.warning("No trained models found", tenant_id=tenant_id) return None - # Get model metrics to validate quality - metrics = await self.clients.training.get_model_metrics( - tenant_id=tenant_id, - model_id=latest_model["id"] - ) - - if metrics and metrics.get("accuracy", 0) > 0.7: # 70% accuracy threshold - logger.info(f"Selected model {latest_model['id']} with accuracy {metrics.get('accuracy')}", - tenant_id=tenant_id) - return latest_model - else: - logger.warning(f"Model accuracy too low: {metrics.get('accuracy', 'unknown')}", - tenant_id=tenant_id) + # ✅ FIX 1: Use "model_id" instead of "id" + model_id = latest_model.get("model_id") + if not model_id: + logger.error("Model response missing model_id field", tenant_id=tenant_id) return None + + # ✅ FIX 2: Handle metrics endpoint failure gracefully + try: + # Get model metrics to validate quality + metrics = await self.clients.training.get_model_metrics( + tenant_id=tenant_id, + model_id=model_id + ) + + # If metrics call succeeded, check accuracy threshold + if metrics and metrics.get("accuracy", 0) > 0.7: # 70% accuracy threshold + logger.info(f"Selected model {model_id} with accuracy {metrics.get('accuracy')}", + tenant_id=tenant_id) + return latest_model + elif metrics: + logger.warning(f"Model accuracy too low: {metrics.get('accuracy', 'unknown')}", + tenant_id=tenant_id) + # Still return the model even if accuracy is low - better than no prediction + logger.info("Returning model despite low accuracy - no alternative available", + tenant_id=tenant_id) + return latest_model + else: + logger.warning("No metrics returned from training service", tenant_id=tenant_id) + # Return model anyway - metrics service might be temporarily down + return latest_model + + except Exception as metrics_error: + # ✅ FIX 3: If metrics endpoint fails, still return the model + logger.warning(f"Failed to get model metrics: {metrics_error}", tenant_id=tenant_id) + logger.info("Proceeding with model despite metrics failure", tenant_id=tenant_id) + return latest_model except Exception as e: logger.error(f"Error selecting best model: {e}", tenant_id=tenant_id) diff --git a/services/training/app/api/models.py b/services/training/app/api/models.py index acecb32f..e589b63e 100644 --- a/services/training/app/api/models.py +++ b/services/training/app/api/models.py @@ -2,14 +2,14 @@ Models API endpoints """ -from fastapi import APIRouter, Depends, HTTPException, status, Path +from fastapi import APIRouter, Depends, HTTPException, status, Path, Query from sqlalchemy.ext.asyncio import AsyncSession -from typing import List +from typing import List, Optional import structlog from sqlalchemy import text from app.core.database import get_db -from app.schemas.training import TrainedModelResponse +from app.schemas.training import TrainedModelResponse, ModelMetricsResponse from app.services.training_service import TrainingService from datetime import datetime @@ -70,7 +70,7 @@ async def get_active_model( await db.commit() return { - "model_id": model_record.id, + "model_id": model_record.id, # ✅ This is the correct field name "model_path": model_record.model_path, "features_used": model_record.features_used, "hyperparameters": model_record.hyperparameters, @@ -92,4 +92,124 @@ async def get_active_model( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to retrieve model" + ) + +@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse) +async def get_model_metrics( + model_id: str = Path(..., description="Model ID"), + db: AsyncSession = Depends(get_db) +): + """ + Get performance metrics for a specific model - used by forecasting service + """ + try: + # Query the model by ID + query = text(""" + SELECT * FROM trained_models + WHERE id = :model_id + """) + + result = await db.execute(query, {"model_id": model_id}) + model_record = result.fetchone() + + if not model_record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Model {model_id} not found" + ) + + # Return metrics in the format expected by forecasting service + metrics = { + "model_id": model_record.id, + "accuracy": model_record.r2_score or 0.0, # Use R2 as accuracy measure + "mape": model_record.mape or 0.0, + "mae": model_record.mae or 0.0, + "rmse": model_record.rmse or 0.0, + "r2_score": model_record.r2_score or 0.0, + "training_samples": model_record.training_samples or 0, + "features_used": model_record.features_used or [], + "model_type": model_record.model_type, + "created_at": model_record.created_at.isoformat() if model_record.created_at else None, + "last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None + } + + logger.info(f"Retrieved metrics for model {model_id}", + mape=metrics["mape"], + accuracy=metrics["accuracy"]) + + return metrics + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get model metrics: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve model metrics" + ) + +@router.get("/tenants/{tenant_id}/models", response_model=List[TrainedModelResponse]) +async def list_models( + tenant_id: str = Path(..., description="Tenant ID"), + status: Optional[str] = Query(None, description="Filter by status (active/inactive)"), + model_type: Optional[str] = Query(None, description="Filter by model type"), + limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"), + db: AsyncSession = Depends(get_db) +): + """ + List models for a tenant - used by forecasting service for model discovery + """ + try: + # Build query with filters + query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"] + params = {"tenant_id": tenant_id} + + if status == "deployed" or status == "active": + query_parts.append("AND is_active = true AND is_production = true") + elif status == "inactive": + query_parts.append("AND (is_active = false OR is_production = false)") + + if model_type: + query_parts.append("AND model_type = :model_type") + params["model_type"] = model_type + + query_parts.append("ORDER BY created_at DESC LIMIT :limit") + params["limit"] = limit + + query = text(" ".join(query_parts)) + result = await db.execute(query, params) + model_records = result.fetchall() + + models = [] + for record in model_records: + models.append({ + "model_id": record.id, + "tenant_id": record.tenant_id, + "product_name": record.product_name, + "model_type": record.model_type, + "model_path": record.model_path, + "version": 1, # Default version + "training_samples": record.training_samples or 0, + "features": record.features_used or [], + "hyperparameters": record.hyperparameters or {}, + "training_metrics": { + "mape": record.mape or 0.0, + "mae": record.mae or 0.0, + "rmse": record.rmse or 0.0, + "r2_score": record.r2_score or 0.0 + }, + "is_active": record.is_active, + "created_at": record.created_at, + "data_period_start": record.training_start_date, + "data_period_end": record.training_end_date + }) + + logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}") + return models + + except Exception as e: + logger.error(f"Failed to list models: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve models" ) \ No newline at end of file diff --git a/services/training/app/schemas/training.py b/services/training/app/schemas/training.py index 52a95cf8..04db04d9 100644 --- a/services/training/app/schemas/training.py +++ b/services/training/app/schemas/training.py @@ -352,11 +352,28 @@ class TrainingErrorUpdate(BaseModel): job_id: str = Field(..., description="Training job identifier") error: str = Field(..., description="Error message") timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp") - + + +class ModelMetricsResponse(BaseModel): + """Response schema for model performance metrics""" + model_id: str = Field(..., description="Unique model identifier") + accuracy: float = Field(..., description="Model accuracy (R2 score)", ge=0.0, le=1.0) + mape: float = Field(..., description="Mean Absolute Percentage Error") + mae: float = Field(..., description="Mean Absolute Error") + rmse: float = Field(..., description="Root Mean Square Error") + r2_score: float = Field(..., description="R-squared score") + training_samples: int = Field(..., description="Number of training samples used") + features_used: List[str] = Field(..., description="List of features used in training") + model_type: str = Field(..., description="Type of ML model") + created_at: Optional[str] = Field(None, description="Model creation timestamp") + last_used_at: Optional[str] = Field(None, description="Last time model was used") + + class Config: + from_attributes = True # Union type for all WebSocket messages TrainingWebSocketMessage = Union[ TrainingProgressUpdate, TrainingCompletedUpdate, TrainingErrorUpdate -] \ No newline at end of file +]