"""
Performance Repository
Repository for model performance metrics operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceRepository(TrainingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are relatively stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
    async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
        """Create a new performance metric record"""
        try:
            # Validate metric data
            validation_result = self._validate_training_data(
                metric_data,
                ["model_id", "tenant_id", "inventory_product_id"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid metric data: {validation_result['errors']}")

            # Set measurement timestamp if not provided
            if "measured_at" not in metric_data:
                metric_data["measured_at"] = datetime.now()

            # Create metric record
            metric = await self.create(metric_data)

            logger.info("Performance metric created",
                        model_id=metric.model_id,
                        tenant_id=metric.tenant_id,
                        inventory_product_id=str(metric.inventory_product_id))
            return metric
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create performance metric",
                         model_id=metric_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create metric: {str(e)}")
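
    # A minimal payload sketch for create_performance_metric. The metric columns
    # (mae, mse, rmse, mape, r2_score, accuracy_percentage, evaluation_samples)
    # are inferred from the queries further down this file; treat the exact set
    # as an assumption against the ModelPerformanceMetric model:
    #
    #     metric = await repo.create_performance_metric({
    #         "model_id": "model-123",            # hypothetical ID
    #         "tenant_id": "tenant-abc",          # hypothetical ID
    #         "inventory_product_id": "prod-42",  # hypothetical ID
    #         "mae": 1.8,
    #         "rmse": 2.4,
    #         "accuracy_percentage": 93.5,
    #         "evaluation_samples": 250,
    #     })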
    async def get_metrics_by_model(
        self,
        model_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get all performance metrics for a model"""
        try:
            return await self.get_multi(
                filters={"model_id": model_id},
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")

    async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
        """Get the latest performance metric for a model"""
        try:
            metrics = await self.get_multi(
                filters={"model_id": model_id},
                limit=1,
                order_by="measured_at",
                order_desc=True
            )
            return metrics[0] if metrics else None
        except Exception as e:
            logger.error("Failed to get latest metric for model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get latest metric: {str(e)}")
    async def get_metrics_by_tenant_and_product(
        self,
        tenant_id: str,
        inventory_product_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics for a tenant's product"""
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id
                },
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by tenant and product",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")
    async def get_metrics_in_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None,
        model_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics within a date range"""
        try:
            # Build parameterized filters; only bound values are sent to the database
            table_name = self.model.__tablename__
            conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
            params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}

            if tenant_id:
                conditions.append("tenant_id = :tenant_id")
                params["tenant_id"] = tenant_id
            if model_id:
                conditions.append("model_id = :model_id")
                params["model_id"] = model_id

            query_text = f"""
                SELECT * FROM {table_name}
                WHERE {' AND '.join(conditions)}
                ORDER BY measured_at DESC
                LIMIT :limit OFFSET :skip
            """
            result = await self.session.execute(text(query_text), params)

            # Convert raw rows back into (detached) model objects
            metrics = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                metric = self.model(**record_dict)
                metrics.append(metric)
            return metrics
        except Exception as e:
            logger.error("Failed to get metrics in date range",
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")
    async def get_performance_trends(
        self,
        tenant_id: str,
        inventory_product_id: Optional[str] = None,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get performance trends for analysis"""
        try:
            start_date = datetime.now() - timedelta(days=days)

            # Build query for per-product performance aggregates
            conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
            params = {"tenant_id": tenant_id, "start_date": start_date}

            if inventory_product_id:
                conditions.append("inventory_product_id = :inventory_product_id")
                params["inventory_product_id"] = inventory_product_id

            query_text = f"""
                SELECT
                    inventory_product_id,
                    AVG(mae) as avg_mae,
                    AVG(mse) as avg_mse,
                    AVG(rmse) as avg_rmse,
                    AVG(mape) as avg_mape,
                    AVG(r2_score) as avg_r2_score,
                    AVG(accuracy_percentage) as avg_accuracy,
                    COUNT(*) as measurement_count,
                    MIN(measured_at) as first_measurement,
                    MAX(measured_at) as last_measurement
                FROM model_performance_metrics
                WHERE {' AND '.join(conditions)}
                GROUP BY inventory_product_id
                ORDER BY avg_accuracy DESC
            """
            result = await self.session.execute(text(query_text), params)

            trends = []
            for row in result.fetchall():
                trends.append({
                    "inventory_product_id": row.inventory_product_id,
                    "metrics": {
                        # Check against None explicitly so a legitimate 0.0 is kept
                        "avg_mae": float(row.avg_mae) if row.avg_mae is not None else None,
                        "avg_mse": float(row.avg_mse) if row.avg_mse is not None else None,
                        "avg_rmse": float(row.avg_rmse) if row.avg_rmse is not None else None,
                        "avg_mape": float(row.avg_mape) if row.avg_mape is not None else None,
                        "avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score is not None else None,
                        "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy is not None else None
                    },
                    "measurement_count": int(row.measurement_count),
                    "period": {
                        "start": row.first_measurement.isoformat() if row.first_measurement else None,
                        "end": row.last_measurement.isoformat() if row.last_measurement else None,
                        "days": days
                    }
                })
            return {
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "trends": trends,
                "period_days": days,
                "total_products": len(trends)
            }
        except Exception as e:
            logger.error("Failed to get performance trends",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            # Degrade gracefully: return an empty trend set instead of raising
            return {
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "trends": [],
                "period_days": days,
                "total_products": 0
            }
    async def get_best_performing_models(
        self,
        tenant_id: str,
        metric_type: str = "accuracy_percentage",
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get best performing models based on a specific metric"""
        try:
            # Whitelist the metric column; it is interpolated into SQL below,
            # so it must never come straight from user input
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # For error metrics (mae, mse, rmse, mape), lower is better;
            # for performance metrics (r2_score, accuracy_percentage), higher is better
            order_desc = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if order_desc else "ASC"

            # DISTINCT ON keeps the latest measurement per (product, model) pair;
            # the outer query then ranks those latest measurements by the metric,
            # so LIMIT actually returns the best performers
            query_text = f"""
                SELECT * FROM (
                    SELECT DISTINCT ON (inventory_product_id, model_id)
                        model_id,
                        inventory_product_id,
                        {metric_type},
                        measured_at,
                        evaluation_samples
                    FROM model_performance_metrics
                    WHERE tenant_id = :tenant_id
                      AND {metric_type} IS NOT NULL
                    ORDER BY inventory_product_id, model_id, measured_at DESC
                ) latest
                ORDER BY {metric_type} {order_direction}
                LIMIT :limit
            """
            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "limit": limit
            })

            best_models = []
            for row in result.fetchall():
                best_models.append({
                    "model_id": row.model_id,
                    "inventory_product_id": row.inventory_product_id,
                    "metric_value": float(getattr(row, metric_type)),
                    "metric_type": metric_type,
                    "measured_at": row.measured_at.isoformat() if row.measured_at else None,
                    "evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
                })
            return best_models
        except Exception as e:
            logger.error("Failed to get best performing models",
                         tenant_id=tenant_id,
                         metric_type=metric_type,
                         error=str(e))
            return []
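
    # A hedged usage sketch (the repo instance and tenant ID are illustrative):
    #
    #     top = await repo.get_best_performing_models("tenant-abc", metric_type="rmse", limit=5)
    #     # -> up to 5 dicts, lowest RMSE first, one per (product, model) pair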
    async def cleanup_old_metrics(self, days_old: int = 180) -> int:
        """Clean up old performance metrics"""
        return await self.cleanup_old_records(days_old=days_old)

    async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get performance metric statistics for a tenant"""
        try:
            # Get basic counts
            total_metrics = await self.count(filters={"tenant_id": tenant_id})

            # Get per-product metrics using a raw query
            product_query = text("""
                SELECT
                    inventory_product_id,
                    COUNT(*) as metric_count,
                    AVG(accuracy_percentage) as avg_accuracy
                FROM model_performance_metrics
                WHERE tenant_id = :tenant_id
                GROUP BY inventory_product_id
                ORDER BY avg_accuracy DESC
            """)
            result = await self.session.execute(product_query, {"tenant_id": tenant_id})

            product_stats = {}
            for row in result.fetchall():
                product_stats[row.inventory_product_id] = {
                    "metric_count": row.metric_count,
                    "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy is not None else None
                }

            # Recent activity (metrics in the last 7 days). Note: the base helper
            # is not tenant-scoped here, so this count is an approximation capped
            # at the limit below
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_metrics = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get an accurate count
            ))

            return {
                "total_metrics": total_metrics,
                "products_tracked": len(product_stats),
                "metrics_by_product": product_stats,
                "recent_metrics_7d": recent_metrics
            }
        except Exception as e:
            logger.error("Failed to get metric statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_metrics": 0,
                "products_tracked": 0,
                "metrics_by_product": {},
                "recent_metrics_7d": 0
            }
    async def compare_model_performance(
        self,
        model_ids: List[str],
        metric_type: str = "accuracy_percentage"
    ) -> Dict[str, Any]:
        """Compare performance between multiple models"""
        try:
            if not model_ids or len(model_ids) < 2:
                return {"error": "At least 2 model IDs required for comparison"}

            # Whitelist the metric column (interpolated into SQL below)
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # Bind model IDs via an expanding parameter rather than string
            # interpolation, which was an SQL injection risk
            query = text(f"""
                SELECT
                    model_id,
                    inventory_product_id,
                    AVG({metric_type}) as avg_metric,
                    MIN({metric_type}) as min_metric,
                    MAX({metric_type}) as max_metric,
                    COUNT(*) as measurement_count,
                    MAX(measured_at) as latest_measurement
                FROM model_performance_metrics
                WHERE model_id IN :model_ids
                  AND {metric_type} IS NOT NULL
                GROUP BY model_id, inventory_product_id
                ORDER BY avg_metric DESC
            """).bindparams(bindparam("model_ids", expanding=True))
            result = await self.session.execute(query, {"model_ids": model_ids})

            comparisons = []
            for row in result.fetchall():
                comparisons.append({
                    "model_id": row.model_id,
                    "inventory_product_id": row.inventory_product_id,
                    "avg_metric": float(row.avg_metric),
                    "min_metric": float(row.min_metric),
                    "max_metric": float(row.max_metric),
                    "measurement_count": int(row.measurement_count),
                    "latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
                })

            # Find best and worst performers; for error metrics, lower is better
            if comparisons:
                higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
                if higher_is_better:
                    best_model = max(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = min(comparisons, key=lambda x: x["avg_metric"])
                else:
                    best_model = min(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = max(comparisons, key=lambda x: x["avg_metric"])
            else:
                best_model = worst_model = None

            return {
                "metric_type": metric_type,
                "models_compared": len(set(comp["model_id"] for comp in comparisons)),
                "comparisons": comparisons,
                "best_performing": best_model,
                "worst_performing": worst_model
            }
        except Exception as e:
            logger.error("Failed to compare model performance",
                         model_ids=model_ids,
                         metric_type=metric_type,
                         error=str(e))
            return {"error": f"Comparison failed: {str(e)}"}