"""
Models API endpoints
"""

from datetime import datetime
from typing import List, Optional

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
# Alias used by handlers that declare a parameter named ``status`` (which
# would otherwise shadow the module inside that function's scope).
from fastapi import status as http_status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import TrainingService
from shared.auth.decorators import (
    get_current_tenant_id_dep
)

logger = structlog.get_logger()
router = APIRouter()

training_service = TrainingService()


@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
    tenant_id: str = Path(..., description="Tenant ID"),
    product_name: str = Path(..., description="Product name"),
    db: AsyncSession = Depends(get_db)
):
    """
    Get the active model for a product - used by forecasting service.

    Selects the most recently created ``trained_models`` row for the given
    tenant and product that has both ``is_active`` and ``is_production`` set,
    stamps its ``last_used_at`` with the current UTC time, and commits.

    Returns:
        dict with ``model_id``, ``model_path``, ``features_used``,
        ``hyperparameters``, ``training_metrics`` (mape/mae/rmse/r2_score),
        ``created_at`` (ISO string or None) and ``training_period``
        (start/end ISO strings or None).

    Raises:
        HTTPException: 404 when no active production model exists for the
            product; 500 on any other failure.
    """
    try:
        # Raw SQL must be wrapped in text() under SQLAlchemy 2.0.
        query = text("""
            SELECT * FROM trained_models
            WHERE tenant_id = :tenant_id
              AND product_name = :product_name
              AND is_active = true
              AND is_production = true
            ORDER BY created_at DESC
            LIMIT 1
        """)

        result = await db.execute(query, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })

        model_record = result.fetchone()

        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No active model found for product {product_name}"
            )

        # Record that the model was just served.
        update_query = text("""
            UPDATE trained_models
            SET last_used_at = :now
            WHERE id = :model_id
        """)

        await db.execute(update_query, {
            "now": datetime.utcnow(),
            "model_id": model_record.id
        })
        await db.commit()

        return {
            "model_id": model_record.id,
            "model_path": model_record.model_path,
            "features_used": model_record.features_used,
            "hyperparameters": model_record.hyperparameters,
            "training_metrics": {
                "mape": model_record.mape,
                "mae": model_record.mae,
                "rmse": model_record.rmse,
                "r2_score": model_record.r2_score
            },
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "training_period": {
                "start_date": model_record.training_start_date.isoformat() if model_record.training_start_date else None,
                "end_date": model_record.training_end_date.isoformat() if model_record.training_end_date else None
            }
        }

    except HTTPException:
        # Propagate the deliberate 404 above instead of masking it as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get active model: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model"
        ) from e


@router.get("/tenants/{tenant_id}/models/{model_id}/metrics", response_model=ModelMetricsResponse)
async def get_model_metrics(
    model_id: str = Path(..., description="Model ID"),
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID")
):
    """
    Get performance metrics for a specific model - used by forecasting service.

    The lookup is scoped to the tenant in the URL path, so one tenant cannot
    read another tenant's model by guessing its id. Missing numeric metrics
    default to 0 and missing ``features_used`` to an empty list.

    Raises:
        HTTPException: 404 when the model does not exist for this tenant;
            500 on any other failure.
    """
    try:
        # Scope by tenant_id as well as id: the path carries the tenant, and
        # ignoring it would expose other tenants' models.
        query = text("""
            SELECT * FROM trained_models
            WHERE id = :model_id
              AND tenant_id = :tenant_id
        """)

        result = await db.execute(query, {"model_id": model_id, "tenant_id": tenant_id})
        model_record = result.fetchone()

        if not model_record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model {model_id} not found"
            )

        # Shape expected by the forecasting service.
        metrics = {
            "model_id": model_record.id,
            "accuracy": model_record.r2_score or 0.0,  # R2 is reported as the accuracy measure
            "mape": model_record.mape or 0.0,
            "mae": model_record.mae or 0.0,
            "rmse": model_record.rmse or 0.0,
            "r2_score": model_record.r2_score or 0.0,
            "training_samples": model_record.training_samples or 0,
            "features_used": model_record.features_used or [],
            "model_type": model_record.model_type,
            "created_at": model_record.created_at.isoformat() if model_record.created_at else None,
            "last_used_at": model_record.last_used_at.isoformat() if model_record.last_used_at else None
        }

        logger.info(f"Retrieved metrics for model {model_id}",
                    mape=metrics["mape"],
                    accuracy=metrics["accuracy"])

        return metrics

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get model metrics: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve model metrics"
        ) from e


def _record_to_model_summary(record) -> dict:
    """Map one ``trained_models`` row to the dict shape returned by ``list_models``."""
    return {
        "model_id": record.id,
        "tenant_id": record.tenant_id,
        "product_name": record.product_name,
        "model_type": record.model_type,
        "model_path": record.model_path,
        "version": 1,  # Default version
        "training_samples": record.training_samples or 0,
        "features": record.features_used or [],
        "hyperparameters": record.hyperparameters or {},
        "training_metrics": {
            "mape": record.mape or 0.0,
            "mae": record.mae or 0.0,
            "rmse": record.rmse or 0.0,
            "r2_score": record.r2_score or 0.0
        },
        "is_active": record.is_active,
        "created_at": record.created_at,
        "data_period_start": record.training_start_date,
        "data_period_end": record.training_end_date
    }


@router.get("/tenants/{tenant_id}/models", response_model=List[TrainedModelResponse])
async def list_models(
    tenant_id: str = Path(..., description="Tenant ID"),
    status: Optional[str] = Query(None, description="Filter by status (active/inactive)"),
    model_type: Optional[str] = Query(None, description="Filter by model type"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of models to return"),
    db: AsyncSession = Depends(get_db)
):
    """
    List models for a tenant - used by forecasting service for model discovery.

    Args:
        tenant_id: Tenant whose models are listed.
        status: ``"active"``/``"deployed"`` keeps rows that are both active
            and production; ``"inactive"`` keeps rows missing either flag;
            any other value (or None) applies no status filter.
        model_type: Optional exact match on ``model_type``.
        limit: Maximum number of rows, newest first.

    Raises:
        HTTPException: 500 on any failure.
    """
    # NOTE: the ``status`` parameter (public query-param name) shadows the
    # fastapi ``status`` module in this scope, so status codes are referenced
    # through the ``http_status`` alias below.
    try:
        # Only fixed SQL fragments are appended; caller-supplied values are
        # always passed as bind parameters.
        query_parts = ["SELECT * FROM trained_models WHERE tenant_id = :tenant_id"]
        params = {"tenant_id": tenant_id}

        if status in ("deployed", "active"):
            query_parts.append("AND is_active = true AND is_production = true")
        elif status == "inactive":
            query_parts.append("AND (is_active = false OR is_production = false)")

        if model_type:
            query_parts.append("AND model_type = :model_type")
            params["model_type"] = model_type

        query_parts.append("ORDER BY created_at DESC LIMIT :limit")
        params["limit"] = limit

        query = text(" ".join(query_parts))

        result = await db.execute(query, params)
        model_records = result.fetchall()

        models = [_record_to_model_summary(record) for record in model_records]

        logger.info(f"Retrieved {len(models)} models for tenant {tenant_id}")

        return models

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to list models: {str(e)}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve models"
        ) from e