""" Performance Repository Repository for model performance metrics operations """ from typing import Optional, List, Dict, Any from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, text, desc from datetime import datetime, timedelta import structlog from .base import TrainingBaseRepository from app.models.training import ModelPerformanceMetric from shared.database.exceptions import DatabaseError, ValidationError logger = structlog.get_logger() class PerformanceRepository(TrainingBaseRepository): """Repository for model performance metrics operations""" def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900): # Performance metrics are relatively stable, longer cache time (15 minutes) super().__init__(ModelPerformanceMetric, session, cache_ttl) async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric: """Create a new performance metric record""" try: # Validate metric data validation_result = self._validate_training_data( metric_data, ["model_id", "tenant_id", "inventory_product_id"] ) if not validation_result["is_valid"]: raise ValidationError(f"Invalid metric data: {validation_result['errors']}") # Set measurement timestamp if not provided if "measured_at" not in metric_data: metric_data["measured_at"] = datetime.now() # Create metric record metric = await self.create(metric_data) logger.info("Performance metric created", model_id=metric.model_id, tenant_id=metric.tenant_id, inventory_product_id=str(metric.inventory_product_id)) return metric except ValidationError: raise except Exception as e: logger.error("Failed to create performance metric", model_id=metric_data.get("model_id"), error=str(e)) raise DatabaseError(f"Failed to create metric: {str(e)}") async def get_metrics_by_model( self, model_id: str, skip: int = 0, limit: int = 100 ) -> List[ModelPerformanceMetric]: """Get all performance metrics for a model""" try: return await self.get_multi( filters={"model_id": model_id}, skip=skip, limit=limit, order_by="measured_at", order_desc=True ) except Exception as e: logger.error("Failed to get metrics by model", model_id=model_id, error=str(e)) raise DatabaseError(f"Failed to get metrics: {str(e)}") async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]: """Get the latest performance metric for a model""" try: metrics = await self.get_multi( filters={"model_id": model_id}, limit=1, order_by="measured_at", order_desc=True ) return metrics[0] if metrics else None except Exception as e: logger.error("Failed to get latest metric for model", model_id=model_id, error=str(e)) raise DatabaseError(f"Failed to get latest metric: {str(e)}") async def get_metrics_by_tenant_and_product( self, tenant_id: str, inventory_product_id: str, skip: int = 0, limit: int = 100 ) -> List[ModelPerformanceMetric]: """Get performance metrics for a tenant's product""" try: return await self.get_multi( filters={ "tenant_id": tenant_id, "inventory_product_id": inventory_product_id }, skip=skip, limit=limit, order_by="measured_at", order_desc=True ) except Exception as e: logger.error("Failed to get metrics by tenant and product", tenant_id=tenant_id, inventory_product_id=inventory_product_id, error=str(e)) raise DatabaseError(f"Failed to get metrics: {str(e)}") async def get_metrics_in_date_range( self, start_date: datetime, end_date: datetime, tenant_id: str = None, model_id: str = None, skip: int = 0, limit: int = 100 ) -> List[ModelPerformanceMetric]: """Get performance metrics within a 
date range""" try: # Build filters table_name = self.model.__tablename__ conditions = ["measured_at >= :start_date", "measured_at <= :end_date"] params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip} if tenant_id: conditions.append("tenant_id = :tenant_id") params["tenant_id"] = tenant_id if model_id: conditions.append("model_id = :model_id") params["model_id"] = model_id query_text = f""" SELECT * FROM {table_name} WHERE {' AND '.join(conditions)} ORDER BY measured_at DESC LIMIT :limit OFFSET :skip """ result = await self.session.execute(text(query_text), params) # Convert rows to model objects metrics = [] for row in result.fetchall(): record_dict = dict(row._mapping) metric = self.model(**record_dict) metrics.append(metric) return metrics except Exception as e: logger.error("Failed to get metrics in date range", start_date=start_date, end_date=end_date, error=str(e)) raise DatabaseError(f"Date range query failed: {str(e)}") async def get_performance_trends( self, tenant_id: str, inventory_product_id: str = None, days: int = 30 ) -> Dict[str, Any]: """Get performance trends for analysis""" try: start_date = datetime.now() - timedelta(days=days) end_date = datetime.now() # Build query for performance trends conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"] params = {"tenant_id": tenant_id, "start_date": start_date} if inventory_product_id: conditions.append("inventory_product_id = :inventory_product_id") params["inventory_product_id"] = inventory_product_id query_text = f""" SELECT inventory_product_id, AVG(mae) as avg_mae, AVG(mse) as avg_mse, AVG(rmse) as avg_rmse, AVG(mape) as avg_mape, AVG(r2_score) as avg_r2_score, AVG(accuracy_percentage) as avg_accuracy, COUNT(*) as measurement_count, MIN(measured_at) as first_measurement, MAX(measured_at) as last_measurement FROM model_performance_metrics WHERE {' AND '.join(conditions)} GROUP BY inventory_product_id ORDER BY avg_accuracy DESC """ result = await self.session.execute(text(query_text), params) trends = [] for row in result.fetchall(): trends.append({ "inventory_product_id": row.inventory_product_id, "metrics": { "avg_mae": float(row.avg_mae) if row.avg_mae else None, "avg_mse": float(row.avg_mse) if row.avg_mse else None, "avg_rmse": float(row.avg_rmse) if row.avg_rmse else None, "avg_mape": float(row.avg_mape) if row.avg_mape else None, "avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None, "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None }, "measurement_count": int(row.measurement_count), "period": { "start": row.first_measurement.isoformat() if row.first_measurement else None, "end": row.last_measurement.isoformat() if row.last_measurement else None, "days": days } }) return { "tenant_id": tenant_id, "inventory_product_id": inventory_product_id, "trends": trends, "period_days": days, "total_products": len(trends) } except Exception as e: logger.error("Failed to get performance trends", tenant_id=tenant_id, inventory_product_id=inventory_product_id, error=str(e)) return { "tenant_id": tenant_id, "inventory_product_id": inventory_product_id, "trends": [], "period_days": days, "total_products": 0 } async def get_best_performing_models( self, tenant_id: str, metric_type: str = "accuracy_percentage", limit: int = 10 ) -> List[Dict[str, Any]]: """Get best performing models based on a specific metric""" try: # Validate metric type valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"] if metric_type not in valid_metrics: 
metric_type = "accuracy_percentage" # For error metrics (mae, mse, rmse, mape), lower is better # For performance metrics (r2_score, accuracy_percentage), higher is better order_desc = metric_type in ["r2_score", "accuracy_percentage"] order_direction = "DESC" if order_desc else "ASC" query_text = f""" SELECT DISTINCT ON (inventory_product_id, model_id) model_id, inventory_product_id, {metric_type}, measured_at, evaluation_samples FROM model_performance_metrics WHERE tenant_id = :tenant_id AND {metric_type} IS NOT NULL ORDER BY inventory_product_id, model_id, measured_at DESC, {metric_type} {order_direction} LIMIT :limit """ result = await self.session.execute(text(query_text), { "tenant_id": tenant_id, "limit": limit }) best_models = [] for row in result.fetchall(): best_models.append({ "model_id": row.model_id, "inventory_product_id": row.inventory_product_id, "metric_value": float(getattr(row, metric_type)), "metric_type": metric_type, "measured_at": row.measured_at.isoformat() if row.measured_at else None, "evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None }) return best_models except Exception as e: logger.error("Failed to get best performing models", tenant_id=tenant_id, metric_type=metric_type, error=str(e)) return [] async def cleanup_old_metrics(self, days_old: int = 180) -> int: """Clean up old performance metrics""" return await self.cleanup_old_records(days_old=days_old) async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]: """Get performance metric statistics for a tenant""" try: # Get basic counts total_metrics = await self.count(filters={"tenant_id": tenant_id}) # Get metrics by product using raw query product_query = text(""" SELECT inventory_product_id, COUNT(*) as metric_count, AVG(accuracy_percentage) as avg_accuracy FROM model_performance_metrics WHERE tenant_id = :tenant_id GROUP BY inventory_product_id ORDER BY avg_accuracy DESC """) result = await self.session.execute(product_query, {"tenant_id": tenant_id}) product_stats = {} for row in result.fetchall(): product_stats[row.inventory_product_id] = { "metric_count": row.metric_count, "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None } # Recent activity (metrics in last 7 days) seven_days_ago = datetime.now() - timedelta(days=7) recent_metrics = len(await self.get_records_by_date_range( seven_days_ago, datetime.now(), limit=1000 # High limit to get accurate count )) return { "total_metrics": total_metrics, "products_tracked": len(product_stats), "metrics_by_product": product_stats, "recent_metrics_7d": recent_metrics } except Exception as e: logger.error("Failed to get metric statistics", tenant_id=tenant_id, error=str(e)) return { "total_metrics": 0, "products_tracked": 0, "metrics_by_product": {}, "recent_metrics_7d": 0 } async def compare_model_performance( self, model_ids: List[str], metric_type: str = "accuracy_percentage" ) -> Dict[str, Any]: """Compare performance between multiple models""" try: if not model_ids or len(model_ids) < 2: return {"error": "At least 2 model IDs required for comparison"} # Validate metric type valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"] if metric_type not in valid_metrics: metric_type = "accuracy_percentage" model_ids_str = "', '".join(model_ids) query_text = f""" SELECT model_id, inventory_product_id, AVG({metric_type}) as avg_metric, MIN({metric_type}) as min_metric, MAX({metric_type}) as max_metric, COUNT(*) as measurement_count, MAX(measured_at) as latest_measurement 
                FROM model_performance_metrics
                WHERE model_id IN ({placeholders})
                  AND {metric_type} IS NOT NULL
                GROUP BY model_id, inventory_product_id
                ORDER BY avg_metric {order_direction}
            """

            result = await self.session.execute(text(query_text), params)

            comparisons = []
            for row in result.fetchall():
                comparisons.append({
                    "model_id": row.model_id,
                    "inventory_product_id": row.inventory_product_id,
                    "avg_metric": float(row.avg_metric),
                    "min_metric": float(row.min_metric),
                    "max_metric": float(row.max_metric),
                    "measurement_count": int(row.measurement_count),
                    "latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None,
                })

            # Find best and worst performing models, respecting the metric direction
            if comparisons:
                ranked = sorted(comparisons, key=lambda c: c["avg_metric"], reverse=higher_is_better)
                best_model, worst_model = ranked[0], ranked[-1]
            else:
                best_model = worst_model = None

            return {
                "metric_type": metric_type,
                "models_compared": len(set(comp["model_id"] for comp in comparisons)),
                "comparisons": comparisons,
                "best_performing": best_model,
                "worst_performing": worst_model,
            }

        except Exception as e:
            logger.error(
                "Failed to compare model performance",
                model_ids=model_ids,
                metric_type=metric_type,
                error=str(e),
            )
            return {"error": f"Comparison failed: {str(e)}"}
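
# --- Usage sketch (illustrative only, not part of the repository API) ---
# A minimal example of how this repository might be driven from an async
# service layer. The identifiers and metric values below are made up, and the
# AsyncSession is assumed to be created and committed by the caller.
async def _example_usage(session: AsyncSession) -> None:
    repo = PerformanceRepository(session)

    # Record a metric for a freshly evaluated model
    await repo.create_performance_metric({
        "model_id": "model-123",            # hypothetical identifiers
        "tenant_id": "tenant-abc",
        "inventory_product_id": "prod-001",
        "mae": 1.8,
        "rmse": 2.4,
        "accuracy_percentage": 92.5,
    })

    # Read back the latest metric and the 30-day trends for the tenant
    latest = await repo.get_latest_metric_for_model("model-123")
    trends = await repo.get_performance_trends("tenant-abc", days=30)
    logger.info(
        "Example usage complete",
        latest_found=latest is not None,
        products=trends["total_products"],
    )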