"""
|
|
Performance Repository
|
|
Repository for model performance metrics operations
|
|
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class PerformanceRepository(TrainingBaseRepository):
    """Repository for model performance metrics operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
        # Performance metrics are relatively stable, so use a longer cache TTL (15 minutes)
        super().__init__(ModelPerformanceMetric, session, cache_ttl)
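
    # Example usage (a minimal sketch; `async_session_factory` is an assumed name for an
    # async_sessionmaker configured elsewhere in the application, and the IDs/values are
    # illustrative):
    #
    #     async with async_session_factory() as session:
    #         repo = PerformanceRepository(session)
    #         metric = await repo.create_performance_metric({
    #             "model_id": "model-123",
    #             "tenant_id": "tenant-1",
    #             "product_name": "widget-a",
    #             "mae": 1.2,
    #             "accuracy_percentage": 94.5,
    #         })
    #         # commit here if the base repository does not commit internally
    #         await session.commit()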

    async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
        """Create a new performance metric record"""
        try:
            # Validate metric data
            validation_result = self._validate_training_data(
                metric_data,
                ["model_id", "tenant_id", "product_name"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid metric data: {validation_result['errors']}")

            # Set measurement timestamp if not provided
            if "measured_at" not in metric_data:
                metric_data["measured_at"] = datetime.now()

            # Create metric record
            metric = await self.create(metric_data)

            logger.info("Performance metric created",
                        model_id=metric.model_id,
                        tenant_id=metric.tenant_id,
                        product_name=metric.product_name)

            return metric

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create performance metric",
                         model_id=metric_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create metric: {str(e)}")

    async def get_metrics_by_model(
        self,
        model_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get all performance metrics for a model"""
        try:
            return await self.get_multi(
                filters={"model_id": model_id},
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")

    async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
        """Get the latest performance metric for a model"""
        try:
            metrics = await self.get_multi(
                filters={"model_id": model_id},
                limit=1,
                order_by="measured_at",
                order_desc=True
            )
            return metrics[0] if metrics else None
        except Exception as e:
            logger.error("Failed to get latest metric for model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get latest metric: {str(e)}")

    async def get_metrics_by_tenant_and_product(
        self,
        tenant_id: str,
        product_name: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics for a tenant's product"""
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name
                },
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by tenant and product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")

    async def get_metrics_in_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None,
        model_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics within a date range"""
        try:
            # Build filters
            table_name = self.model.__tablename__
            conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
            params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}

            if tenant_id:
                conditions.append("tenant_id = :tenant_id")
                params["tenant_id"] = tenant_id

            if model_id:
                conditions.append("model_id = :model_id")
                params["model_id"] = model_id

            query_text = f"""
                SELECT * FROM {table_name}
                WHERE {' AND '.join(conditions)}
                ORDER BY measured_at DESC
                LIMIT :limit OFFSET :skip
            """

            result = await self.session.execute(text(query_text), params)

            # Convert rows to model objects
            metrics = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                metric = self.model(**record_dict)
                metrics.append(metric)

            return metrics

        except Exception as e:
            logger.error("Failed to get metrics in date range",
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")

    async def get_performance_trends(
        self,
        tenant_id: str,
        product_name: Optional[str] = None,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get performance trends for analysis"""
        try:
            start_date = datetime.now() - timedelta(days=days)

            # Build query for performance trends
            conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
            params = {"tenant_id": tenant_id, "start_date": start_date}

            if product_name:
                conditions.append("product_name = :product_name")
                params["product_name"] = product_name

            query_text = f"""
                SELECT
                    product_name,
                    AVG(mae) as avg_mae,
                    AVG(mse) as avg_mse,
                    AVG(rmse) as avg_rmse,
                    AVG(mape) as avg_mape,
                    AVG(r2_score) as avg_r2_score,
                    AVG(accuracy_percentage) as avg_accuracy,
                    COUNT(*) as measurement_count,
                    MIN(measured_at) as first_measurement,
                    MAX(measured_at) as last_measurement
                FROM model_performance_metrics
                WHERE {' AND '.join(conditions)}
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """

            result = await self.session.execute(text(query_text), params)

            trends = []
            for row in result.fetchall():
                trends.append({
                    "product_name": row.product_name,
                    "metrics": {
                        "avg_mae": float(row.avg_mae) if row.avg_mae is not None else None,
                        "avg_mse": float(row.avg_mse) if row.avg_mse is not None else None,
                        "avg_rmse": float(row.avg_rmse) if row.avg_rmse is not None else None,
                        "avg_mape": float(row.avg_mape) if row.avg_mape is not None else None,
                        "avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score is not None else None,
                        "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy is not None else None
                    },
                    "measurement_count": int(row.measurement_count),
                    "period": {
                        "start": row.first_measurement.isoformat() if row.first_measurement else None,
                        "end": row.last_measurement.isoformat() if row.last_measurement else None,
                        "days": days
                    }
                })

            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": trends,
                "period_days": days,
                "total_products": len(trends)
            }

        except Exception as e:
            logger.error("Failed to get performance trends",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": [],
                "period_days": days,
                "total_products": 0
            }

    async def get_best_performing_models(
        self,
        tenant_id: str,
        metric_type: str = "accuracy_percentage",
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get best performing models based on a specific metric"""
        try:
            # Validate metric type
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # For error metrics (mae, mse, rmse, mape), lower is better
            # For performance metrics (r2_score, accuracy_percentage), higher is better
            order_desc = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if order_desc else "ASC"

            # Take the latest measurement per (product, model), then rank those by the metric
            query_text = f"""
                SELECT model_id, product_name, {metric_type}, measured_at, evaluation_samples
                FROM (
                    SELECT DISTINCT ON (product_name, model_id)
                        model_id,
                        product_name,
                        {metric_type},
                        measured_at,
                        evaluation_samples
                    FROM model_performance_metrics
                    WHERE tenant_id = :tenant_id
                      AND {metric_type} IS NOT NULL
                    ORDER BY product_name, model_id, measured_at DESC
                ) latest_metrics
                ORDER BY {metric_type} {order_direction}
                LIMIT :limit
            """

            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "limit": limit
            })

            best_models = []
            for row in result.fetchall():
                best_models.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "metric_value": float(getattr(row, metric_type)),
                    "metric_type": metric_type,
                    "measured_at": row.measured_at.isoformat() if row.measured_at else None,
                    "evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples is not None else None
                })

            return best_models

        except Exception as e:
            logger.error("Failed to get best performing models",
                         tenant_id=tenant_id,
                         metric_type=metric_type,
                         error=str(e))
            return []
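
    # Direction matters when ranking (illustrative calls, assuming a `repo` instance as in
    # the usage sketch above): "accuracy_percentage" ranks descending (higher is better),
    # while an error metric such as "mae" ranks ascending (lower is better).
    #
    #     top_by_accuracy = await repo.get_best_performing_models("tenant-1", "accuracy_percentage")
    #     top_by_mae = await repo.get_best_performing_models("tenant-1", "mae")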

    async def cleanup_old_metrics(self, days_old: int = 180) -> int:
        """Clean up old performance metrics"""
        return await self.cleanup_old_records(days_old=days_old)

    async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get performance metric statistics for a tenant"""
        try:
            # Get basic counts
            total_metrics = await self.count(filters={"tenant_id": tenant_id})

            # Get metrics by product using a raw query
            product_query = text("""
                SELECT
                    product_name,
                    COUNT(*) as metric_count,
                    AVG(accuracy_percentage) as avg_accuracy
                FROM model_performance_metrics
                WHERE tenant_id = :tenant_id
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {}

            for row in result.fetchall():
                product_stats[row.product_name] = {
                    "metric_count": row.metric_count,
                    "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy is not None else None
                }

            # Recent activity (metrics in the last 7 days)
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_metrics = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get an accurate count
            ))

            return {
                "total_metrics": total_metrics,
                "products_tracked": len(product_stats),
                "metrics_by_product": product_stats,
                "recent_metrics_7d": recent_metrics
            }

        except Exception as e:
            logger.error("Failed to get metric statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_metrics": 0,
                "products_tracked": 0,
                "metrics_by_product": {},
                "recent_metrics_7d": 0
            }

    async def compare_model_performance(
        self,
        model_ids: List[str],
        metric_type: str = "accuracy_percentage"
    ) -> Dict[str, Any]:
        """Compare performance between multiple models"""
        try:
            if not model_ids or len(model_ids) < 2:
                return {"error": "At least 2 model IDs required for comparison"}

            # Validate metric type
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # For error metrics (mae, mse, rmse, mape), lower is better
            higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if higher_is_better else "ASC"

            # Bind model IDs as parameters instead of interpolating them into the SQL string
            id_placeholders = ", ".join(f":model_id_{i}" for i in range(len(model_ids)))
            params = {f"model_id_{i}": model_id for i, model_id in enumerate(model_ids)}

            query_text = f"""
                SELECT
                    model_id,
                    product_name,
                    AVG({metric_type}) as avg_metric,
                    MIN({metric_type}) as min_metric,
                    MAX({metric_type}) as max_metric,
                    COUNT(*) as measurement_count,
                    MAX(measured_at) as latest_measurement
                FROM model_performance_metrics
                WHERE model_id IN ({id_placeholders})
                  AND {metric_type} IS NOT NULL
                GROUP BY model_id, product_name
                ORDER BY avg_metric {order_direction}
            """

            result = await self.session.execute(text(query_text), params)

            comparisons = []
            for row in result.fetchall():
                comparisons.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "avg_metric": float(row.avg_metric),
                    "min_metric": float(row.min_metric),
                    "max_metric": float(row.max_metric),
                    "measurement_count": int(row.measurement_count),
                    "latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
                })

            # Find best and worst performing models, honoring the metric direction
            if comparisons:
                if higher_is_better:
                    best_model = max(comparisons, key=lambda c: c["avg_metric"])
                    worst_model = min(comparisons, key=lambda c: c["avg_metric"])
                else:
                    best_model = min(comparisons, key=lambda c: c["avg_metric"])
                    worst_model = max(comparisons, key=lambda c: c["avg_metric"])
            else:
                best_model = worst_model = None

            return {
                "metric_type": metric_type,
                "models_compared": len(set(comp["model_id"] for comp in comparisons)),
                "comparisons": comparisons,
                "best_performing": best_model,
                "worst_performing": worst_model
            }

        except Exception as e:
            logger.error("Failed to compare model performance",
                         model_ids=model_ids,
                         metric_type=metric_type,
                         error=str(e))
            return {"error": f"Comparison failed: {str(e)}"}