"""
Performance Repository
Repository for model performance metrics operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceRepository(TrainingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are relatively stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric record"""
try:
# Validate metric data
validation_result = self._validate_training_data(
metric_data,
["model_id", "tenant_id", "product_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
# Set measurement timestamp if not provided
if "measured_at" not in metric_data:
metric_data["measured_at"] = datetime.now()
# Create metric record
metric = await self.create(metric_data)
logger.info("Performance metric created",
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all performance metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="measured_at",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_metrics_by_tenant_and_product(
self,
tenant_id: str,
product_name: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics for a tenant's product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by tenant and product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_metrics_in_date_range(
self,
start_date: datetime,
end_date: datetime,
        tenant_id: Optional[str] = None,
        model_id: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics within a date range"""
try:
# Build filters
table_name = self.model.__tablename__
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
if tenant_id:
conditions.append("tenant_id = :tenant_id")
params["tenant_id"] = tenant_id
if model_id:
conditions.append("model_id = :model_id")
params["model_id"] = model_id
query_text = f"""
SELECT * FROM {table_name}
WHERE {' AND '.join(conditions)}
ORDER BY measured_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), params)
# Convert rows to model objects
metrics = []
for row in result.fetchall():
record_dict = dict(row._mapping)
metric = self.model(**record_dict)
metrics.append(metric)
return metrics
except Exception as e:
logger.error("Failed to get metrics in date range",
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
        product_name: Optional[str] = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends for analysis"""
try:
start_date = datetime.now() - timedelta(days=days)
end_date = datetime.now()
# Build query for performance trends
conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
params = {"tenant_id": tenant_id, "start_date": start_date}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
query_text = f"""
SELECT
product_name,
AVG(mae) as avg_mae,
AVG(mse) as avg_mse,
AVG(rmse) as avg_rmse,
AVG(mape) as avg_mape,
AVG(r2_score) as avg_r2_score,
AVG(accuracy_percentage) as avg_accuracy,
COUNT(*) as measurement_count,
MIN(measured_at) as first_measurement,
MAX(measured_at) as last_measurement
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY product_name
ORDER BY avg_accuracy DESC
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"product_name": row.product_name,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mse": float(row.avg_mse) if row.avg_mse else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count),
"period": {
"start": row.first_measurement.isoformat() if row.first_measurement else None,
"end": row.last_measurement.isoformat() if row.last_measurement else None,
"days": days
}
})
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": trends,
"period_days": days,
"total_products": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": [],
"period_days": days,
"total_products": 0
}
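    # Usage sketch (hypothetical tenant id): per-product averages over two weeks:
    #
    #     trends = await repo.get_performance_trends("tenant-abc", days=14)
    #     for entry in trends["trends"]:
    #         print(entry["product_name"], entry["metrics"]["avg_accuracy"])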
async def get_best_performing_models(
self,
tenant_id: str,
metric_type: str = "accuracy_percentage",
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get best performing models based on a specific metric"""
try:
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
            # For error metrics (mae, mse, rmse, mape), lower is better
            # For performance metrics (r2_score, accuracy_percentage), higher is better
            order_desc = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if order_desc else "ASC"
            # Pick the latest measurement per (product, model) with DISTINCT ON,
            # then rank those latest measurements by the requested metric so that
            # LIMIT returns the actual best performers. metric_type is safe to
            # interpolate because it was validated against the whitelist above.
            query_text = f"""
                SELECT model_id, product_name, {metric_type}, measured_at, evaluation_samples
                FROM (
                    SELECT DISTINCT ON (product_name, model_id)
                        model_id,
                        product_name,
                        {metric_type},
                        measured_at,
                        evaluation_samples
                    FROM model_performance_metrics
                    WHERE tenant_id = :tenant_id
                        AND {metric_type} IS NOT NULL
                    ORDER BY product_name, model_id, measured_at DESC
                ) latest_metrics
                ORDER BY {metric_type} {order_direction}
                LIMIT :limit
            """
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"limit": limit
})
best_models = []
for row in result.fetchall():
best_models.append({
"model_id": row.model_id,
"product_name": row.product_name,
"metric_value": float(getattr(row, metric_type)),
"metric_type": metric_type,
"measured_at": row.measured_at.isoformat() if row.measured_at else None,
"evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
})
return best_models
except Exception as e:
logger.error("Failed to get best performing models",
tenant_id=tenant_id,
metric_type=metric_type,
error=str(e))
return []
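    # Usage sketch (hypothetical tenant id): rank models by MAPE, where lower is
    # better, so the query orders ascending automatically:
    #
    #     best = await repo.get_best_performing_models(
    #         "tenant-abc", metric_type="mape", limit=5
    #     )
    #     for entry in best:
    #         print(entry["model_id"], entry["product_name"], entry["metric_value"])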
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get performance metric statistics for a tenant"""
try:
# Get basic counts
total_metrics = await self.count(filters={"tenant_id": tenant_id})
# Get metrics by product using raw query
product_query = text("""
SELECT
product_name,
COUNT(*) as metric_count,
AVG(accuracy_percentage) as avg_accuracy
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
GROUP BY product_name
ORDER BY avg_accuracy DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {}
for row in result.fetchall():
product_stats[row.product_name] = {
"metric_count": row.metric_count,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
}
# Recent activity (metrics in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_metrics = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
                limit=1000  # approximate; the count is capped at 1,000 records
))
return {
"total_metrics": total_metrics,
"products_tracked": len(product_stats),
"metrics_by_product": product_stats,
"recent_metrics_7d": recent_metrics
}
except Exception as e:
logger.error("Failed to get metric statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_metrics": 0,
"products_tracked": 0,
"metrics_by_product": {},
"recent_metrics_7d": 0
}
async def compare_model_performance(
self,
model_ids: List[str],
metric_type: str = "accuracy_percentage"
) -> Dict[str, Any]:
"""Compare performance between multiple models"""
try:
if not model_ids or len(model_ids) < 2:
return {"error": "At least 2 model IDs required for comparison"}
            # Validate metric type against a whitelist (it is interpolated into the SQL below)
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"
            # For error metrics lower is better; for r2_score and accuracy_percentage higher is better
            higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
            # Bind each model id as a named parameter instead of splicing the raw
            # values into the SQL string, which would be open to SQL injection
            id_params = {f"model_id_{i}": mid for i, mid in enumerate(model_ids)}
            id_placeholders = ", ".join(f":{name}" for name in id_params)
            query_text = f"""
                SELECT
                    model_id,
                    product_name,
                    AVG({metric_type}) as avg_metric,
                    MIN({metric_type}) as min_metric,
                    MAX({metric_type}) as max_metric,
                    COUNT(*) as measurement_count,
                    MAX(measured_at) as latest_measurement
                FROM model_performance_metrics
                WHERE model_id IN ({id_placeholders})
                    AND {metric_type} IS NOT NULL
                GROUP BY model_id, product_name
                ORDER BY avg_metric {'DESC' if higher_is_better else 'ASC'}
            """
            result = await self.session.execute(text(query_text), id_params)
comparisons = []
for row in result.fetchall():
comparisons.append({
"model_id": row.model_id,
"product_name": row.product_name,
"avg_metric": float(row.avg_metric),
"min_metric": float(row.min_metric),
"max_metric": float(row.max_metric),
"measurement_count": int(row.measurement_count),
"latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
})
            # Find best and worst performing models, respecting metric direction
            if comparisons:
                if higher_is_better:
                    best_model = max(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = min(comparisons, key=lambda x: x["avg_metric"])
                else:
                    best_model = min(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = max(comparisons, key=lambda x: x["avg_metric"])
            else:
                best_model = worst_model = None
return {
"metric_type": metric_type,
"models_compared": len(set(comp["model_id"] for comp in comparisons)),
"comparisons": comparisons,
"best_performing": best_model,
"worst_performing": worst_model
}
except Exception as e:
logger.error("Failed to compare model performance",
model_ids=model_ids,
metric_type=metric_type,
error=str(e))
return {"error": f"Comparison failed: {str(e)}"}