REFACTOR - Database logic
433
services/training/app/repositories/performance_repository.py
Normal file
@@ -0,0 +1,433 @@
"""
Performance Repository

Repository for model performance metrics operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class PerformanceRepository(TrainingBaseRepository):
    """Repository for model performance metrics operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
        # Performance metrics are relatively stable, longer cache time (15 minutes)
        super().__init__(ModelPerformanceMetric, session, cache_ttl)

    async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
        """Create a new performance metric record"""
        try:
            # Validate metric data
            validation_result = self._validate_training_data(
                metric_data,
                ["model_id", "tenant_id", "product_name"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid metric data: {validation_result['errors']}")

            # Set measurement timestamp if not provided
            if "measured_at" not in metric_data:
                metric_data["measured_at"] = datetime.now()

            # Create metric record
            metric = await self.create(metric_data)

            logger.info("Performance metric created",
                        model_id=metric.model_id,
                        tenant_id=metric.tenant_id,
                        product_name=metric.product_name)

            return metric

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create performance metric",
                         model_id=metric_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create metric: {str(e)}")

    async def get_metrics_by_model(
        self,
        model_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get all performance metrics for a model"""
        try:
            return await self.get_multi(
                filters={"model_id": model_id},
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")

    async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
        """Get the latest performance metric for a model"""
        try:
            metrics = await self.get_multi(
                filters={"model_id": model_id},
                limit=1,
                order_by="measured_at",
                order_desc=True
            )
            return metrics[0] if metrics else None
        except Exception as e:
            logger.error("Failed to get latest metric for model",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get latest metric: {str(e)}")

    async def get_metrics_by_tenant_and_product(
        self,
        tenant_id: str,
        product_name: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics for a tenant's product"""
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name
                },
                skip=skip,
                limit=limit,
                order_by="measured_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get metrics by tenant and product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get metrics: {str(e)}")

    async def get_metrics_in_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None,
        model_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelPerformanceMetric]:
        """Get performance metrics within a date range"""
        try:
            # Build filters
            table_name = self.model.__tablename__
            conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
            params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}

            if tenant_id:
                conditions.append("tenant_id = :tenant_id")
                params["tenant_id"] = tenant_id

            if model_id:
                conditions.append("model_id = :model_id")
                params["model_id"] = model_id

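            # Only trusted identifiers are interpolated into the SQL string below:
            # the table name comes from the mapped model's metadata, while the
            # date range and optional filter values are passed as bound parameters.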
            query_text = f"""
                SELECT * FROM {table_name}
                WHERE {' AND '.join(conditions)}
                ORDER BY measured_at DESC
                LIMIT :limit OFFSET :skip
            """

            result = await self.session.execute(text(query_text), params)

            # Convert rows to model objects
            metrics = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                metric = self.model(**record_dict)
                metrics.append(metric)

            return metrics

        except Exception as e:
            logger.error("Failed to get metrics in date range",
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")

    async def get_performance_trends(
        self,
        tenant_id: str,
        product_name: Optional[str] = None,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get performance trends for analysis"""
        try:
            start_date = datetime.now() - timedelta(days=days)
            end_date = datetime.now()

            # Build query for performance trends
            conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
            params = {"tenant_id": tenant_id, "start_date": start_date}

            if product_name:
                conditions.append("product_name = :product_name")
                params["product_name"] = product_name

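            # Aggregate the error metrics (MAE/MSE/RMSE/MAPE) and score metrics
            # (R2, accuracy) per product; SQL AVG() skips NULL values, so products
            # with partially populated metrics still produce a row.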
            query_text = f"""
                SELECT
                    product_name,
                    AVG(mae) as avg_mae,
                    AVG(mse) as avg_mse,
                    AVG(rmse) as avg_rmse,
                    AVG(mape) as avg_mape,
                    AVG(r2_score) as avg_r2_score,
                    AVG(accuracy_percentage) as avg_accuracy,
                    COUNT(*) as measurement_count,
                    MIN(measured_at) as first_measurement,
                    MAX(measured_at) as last_measurement
                FROM model_performance_metrics
                WHERE {' AND '.join(conditions)}
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """

            result = await self.session.execute(text(query_text), params)

            trends = []
            for row in result.fetchall():
                trends.append({
                    "product_name": row.product_name,
                    "metrics": {
                        "avg_mae": float(row.avg_mae) if row.avg_mae else None,
                        "avg_mse": float(row.avg_mse) if row.avg_mse else None,
                        "avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
                        "avg_mape": float(row.avg_mape) if row.avg_mape else None,
                        "avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
                        "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
                    },
                    "measurement_count": int(row.measurement_count),
                    "period": {
                        "start": row.first_measurement.isoformat() if row.first_measurement else None,
                        "end": row.last_measurement.isoformat() if row.last_measurement else None,
                        "days": days
                    }
                })

            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": trends,
                "period_days": days,
                "total_products": len(trends)
            }

        except Exception as e:
            logger.error("Failed to get performance trends",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": [],
                "period_days": days,
                "total_products": 0
            }

    async def get_best_performing_models(
        self,
        tenant_id: str,
        metric_type: str = "accuracy_percentage",
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get best performing models based on a specific metric"""
        try:
            # Validate metric type
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # For error metrics (mae, mse, rmse, mape), lower is better
            # For performance metrics (r2_score, accuracy_percentage), higher is better
            order_desc = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if order_desc else "ASC"

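            # PostgreSQL's DISTINCT ON keeps one row per (product_name, model_id);
            # ordering by measured_at DESC inside the subquery selects each model's
            # latest measurement, and the outer ORDER BY ranks those by the chosen
            # (whitelisted) metric column.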
            query_text = f"""
                SELECT * FROM (
                    SELECT DISTINCT ON (product_name, model_id)
                        model_id,
                        product_name,
                        {metric_type},
                        measured_at,
                        evaluation_samples
                    FROM model_performance_metrics
                    WHERE tenant_id = :tenant_id
                        AND {metric_type} IS NOT NULL
                    ORDER BY product_name, model_id, measured_at DESC
                ) AS latest_metrics
                ORDER BY {metric_type} {order_direction}
                LIMIT :limit
            """

            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "limit": limit
            })

            best_models = []
            for row in result.fetchall():
                best_models.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "metric_value": float(getattr(row, metric_type)),
                    "metric_type": metric_type,
                    "measured_at": row.measured_at.isoformat() if row.measured_at else None,
                    "evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
                })

            return best_models

        except Exception as e:
            logger.error("Failed to get best performing models",
                         tenant_id=tenant_id,
                         metric_type=metric_type,
                         error=str(e))
            return []

    async def cleanup_old_metrics(self, days_old: int = 180) -> int:
        """Clean up old performance metrics"""
        return await self.cleanup_old_records(days_old=days_old)

    async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get performance metric statistics for a tenant"""
        try:
            # Get basic counts
            total_metrics = await self.count(filters={"tenant_id": tenant_id})

            # Get metrics by product using raw query
            product_query = text("""
                SELECT
                    product_name,
                    COUNT(*) as metric_count,
                    AVG(accuracy_percentage) as avg_accuracy
                FROM model_performance_metrics
                WHERE tenant_id = :tenant_id
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {}

            for row in result.fetchall():
                product_stats[row.product_name] = {
                    "metric_count": row.metric_count,
                    "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
                }

            # Recent activity (metrics in last 7 days)
            seven_days_ago = datetime.now() - timedelta(days=7)
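            # get_records_by_date_range is provided by the base repository; the
            # count below is capped by the fetch limit, so it is approximate for
            # tenants with more than 1000 metrics in the window.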
            recent_metrics = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get accurate count
            ))

            return {
                "total_metrics": total_metrics,
                "products_tracked": len(product_stats),
                "metrics_by_product": product_stats,
                "recent_metrics_7d": recent_metrics
            }

        except Exception as e:
            logger.error("Failed to get metric statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_metrics": 0,
                "products_tracked": 0,
                "metrics_by_product": {},
                "recent_metrics_7d": 0
            }

    async def compare_model_performance(
        self,
        model_ids: List[str],
        metric_type: str = "accuracy_percentage"
    ) -> Dict[str, Any]:
        """Compare performance between multiple models"""
        try:
            if not model_ids or len(model_ids) < 2:
                return {"error": "At least 2 model IDs required for comparison"}

            # Validate metric type
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

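            # Bind each model ID as a named parameter instead of splicing the raw
            # values into the SQL string, so untrusted IDs cannot break the query.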
            placeholders = ", ".join(f":model_id_{i}" for i in range(len(model_ids)))
            params = {f"model_id_{i}": model_id for i, model_id in enumerate(model_ids)}

            query_text = f"""
                SELECT
                    model_id,
                    product_name,
                    AVG({metric_type}) as avg_metric,
                    MIN({metric_type}) as min_metric,
                    MAX({metric_type}) as max_metric,
                    COUNT(*) as measurement_count,
                    MAX(measured_at) as latest_measurement
                FROM model_performance_metrics
                WHERE model_id IN ({placeholders})
                    AND {metric_type} IS NOT NULL
                GROUP BY model_id, product_name
                ORDER BY avg_metric DESC
            """

            result = await self.session.execute(text(query_text), params)

            comparisons = []
            for row in result.fetchall():
                comparisons.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "avg_metric": float(row.avg_metric),
                    "min_metric": float(row.min_metric),
                    "max_metric": float(row.max_metric),
                    "measurement_count": int(row.measurement_count),
                    "latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
                })

            # Find best and worst performing models; for error metrics
            # (mae, mse, rmse, mape) a lower average is better
            higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
            if comparisons:
                if higher_is_better:
                    best_model = max(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = min(comparisons, key=lambda x: x["avg_metric"])
                else:
                    best_model = min(comparisons, key=lambda x: x["avg_metric"])
                    worst_model = max(comparisons, key=lambda x: x["avg_metric"])
            else:
                best_model = worst_model = None

            return {
                "metric_type": metric_type,
                "models_compared": len(set(comp["model_id"] for comp in comparisons)),
                "comparisons": comparisons,
                "best_performing": best_model,
                "worst_performing": worst_model
            }

        except Exception as e:
            logger.error("Failed to compare model performance",
                         model_ids=model_ids,
                         metric_type=metric_type,
                         error=str(e))
            return {"error": f"Comparison failed: {str(e)}"}