Improve backend: add forecast validation, historical backfill, performance monitoring, and automatic retraining services to the forecasting service

Urtzi Alfaro
2025-11-18 07:17:17 +01:00
parent d36f2ab9af
commit 5c45164c8e
61 changed files with 9846 additions and 495 deletions


@@ -0,0 +1,480 @@
# ================================================================
# services/forecasting/app/services/historical_validation_service.py
# ================================================================
"""
Historical Validation Service
Handles validation backfill when historical sales data is uploaded late.
Detects gaps in validation coverage and automatically triggers validation
for periods where forecasts exist but haven't been validated yet.
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date
from datetime import datetime, timedelta, timezone, date
import structlog
import uuid
from app.models.forecasts import Forecast
from app.models.validation_run import ValidationRun
from app.models.sales_data_update import SalesDataUpdate
from app.services.validation_service import ValidationService
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class HistoricalValidationService:
"""Service for backfilling historical validation when sales data arrives late"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
self.validation_service = ValidationService(db_session)
async def detect_validation_gaps(
self,
tenant_id: uuid.UUID,
lookback_days: int = 90
) -> List[Dict[str, Any]]:
"""
Detect date ranges where forecasts exist but haven't been validated
Args:
tenant_id: Tenant identifier
lookback_days: How far back to check (default 90 days)
Returns:
List of gap periods with date ranges
"""
try:
end_date = datetime.now(timezone.utc)
start_date = end_date - timedelta(days=lookback_days)
logger.info(
"Detecting validation gaps",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
# Get all dates with forecasts
forecast_query = select(
func.cast(Forecast.forecast_date, Date).label('forecast_date')
).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.forecast_date >= start_date,
Forecast.forecast_date <= end_date
)
).group_by(
func.cast(Forecast.forecast_date, Date)
).order_by(
func.cast(Forecast.forecast_date, Date)
)
forecast_result = await self.db.execute(forecast_query)
forecast_dates = {row.forecast_date for row in forecast_result.fetchall()}
if not forecast_dates:
logger.info("No forecasts found in lookback period", tenant_id=tenant_id)
return []
# Get all dates that have been validated (expand each run's range into individual days)
validation_query = select(
func.cast(ValidationRun.validation_start_date, Date).label('validated_start'),
func.cast(ValidationRun.validation_end_date, Date).label('validated_end')
).where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.validation_start_date >= start_date,
ValidationRun.validation_end_date <= end_date
)
)
validation_result = await self.db.execute(validation_query)
# A run covering several days validates every day in its range, not just the start date
validated_dates = set()
for row in validation_result.fetchall():
span_days = (row.validated_end - row.validated_start).days
validated_dates.update(row.validated_start + timedelta(days=offset) for offset in range(span_days + 1))
# Find gaps (dates with forecasts but no validation)
gap_dates = sorted(forecast_dates - validated_dates)
if not gap_dates:
logger.info("No validation gaps found", tenant_id=tenant_id)
return []
# Group consecutive dates into ranges
gaps = []
current_gap_start = gap_dates[0]
current_gap_end = gap_dates[0]
for i in range(1, len(gap_dates)):
if (gap_dates[i] - current_gap_end).days == 1:
# Consecutive date, extend current gap
current_gap_end = gap_dates[i]
else:
# Gap in dates, save current gap and start new one
gaps.append({
"start_date": current_gap_start,
"end_date": current_gap_end,
"days_count": (current_gap_end - current_gap_start).days + 1
})
current_gap_start = gap_dates[i]
current_gap_end = gap_dates[i]
# Don't forget the last gap
gaps.append({
"start_date": current_gap_start,
"end_date": current_gap_end,
"days_count": (current_gap_end - current_gap_start).days + 1
})
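# Example: gap_dates of [Mar 1, Mar 2, Mar 3, Mar 7] collapse into two ranges:
# {start: Mar 1, end: Mar 3, days_count: 3} and {start: Mar 7, end: Mar 7, days_count: 1}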
logger.info(
"Validation gaps detected",
tenant_id=tenant_id,
gaps_count=len(gaps),
total_days=len(gap_dates)
)
return gaps
except Exception as e:
logger.error(
"Failed to detect validation gaps",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to detect validation gaps: {str(e)}")
async def backfill_validation(
self,
tenant_id: uuid.UUID,
start_date: date,
end_date: date,
triggered_by: str = "manual",
sales_data_update_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
Backfill validation for a historical date range
Args:
tenant_id: Tenant identifier
start_date: Start date for backfill
end_date: End date for backfill
triggered_by: How this backfill was triggered
sales_data_update_id: Optional link to sales data update record
Returns:
Backfill results with validation summary
"""
try:
logger.info(
"Starting validation backfill",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
triggered_by=triggered_by
)
# Convert dates to datetime
start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc)
end_datetime = datetime.combine(end_date, datetime.max.time()).replace(tzinfo=timezone.utc)
# Run validation for the date range
validation_result = await self.validation_service.validate_date_range(
tenant_id=tenant_id,
start_date=start_datetime,
end_date=end_datetime,
orchestration_run_id=None,
triggered_by=triggered_by
)
# Update sales data update record if provided
if sales_data_update_id:
await self._update_sales_data_record(
sales_data_update_id=sales_data_update_id,
validation_run_id=uuid.UUID(validation_result["validation_run_id"]),
status="completed" if validation_result["status"] == "completed" else "failed"
)
logger.info(
"Validation backfill completed",
tenant_id=tenant_id,
validation_run_id=validation_result.get("validation_run_id"),
forecasts_evaluated=validation_result.get("forecasts_evaluated")
)
return {
**validation_result,
"backfill_date_range": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
}
}
except Exception as e:
logger.error(
"Validation backfill failed",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
error=str(e)
)
if sales_data_update_id:
await self._update_sales_data_record(
sales_data_update_id=sales_data_update_id,
validation_run_id=None,
status="failed",
error_message=str(e)
)
raise DatabaseError(f"Validation backfill failed: {str(e)}")
async def auto_backfill_gaps(
self,
tenant_id: uuid.UUID,
lookback_days: int = 90,
max_gaps_to_process: int = 10
) -> Dict[str, Any]:
"""
Automatically detect and backfill validation gaps
Args:
tenant_id: Tenant identifier
lookback_days: How far back to check
max_gaps_to_process: Maximum number of gaps to process in one run
Returns:
Summary of backfill operations
"""
try:
logger.info(
"Starting auto backfill",
tenant_id=tenant_id,
lookback_days=lookback_days
)
# Detect gaps
gaps = await self.detect_validation_gaps(tenant_id, lookback_days)
if not gaps:
return {
"gaps_found": 0,
"gaps_processed": 0,
"validations_completed": 0,
"message": "No validation gaps found"
}
# Limit number of gaps to process
gaps_to_process = gaps[:max_gaps_to_process]
results = []
for gap in gaps_to_process:
try:
result = await self.backfill_validation(
tenant_id=tenant_id,
start_date=gap["start_date"],
end_date=gap["end_date"],
triggered_by="auto_backfill"
)
results.append({
"gap": gap,
"result": result,
"status": "success"
})
except Exception as e:
logger.error(
"Failed to backfill gap",
gap=gap,
error=str(e)
)
results.append({
"gap": gap,
"error": str(e),
"status": "failed"
})
successful = sum(1 for r in results if r["status"] == "success")
logger.info(
"Auto backfill completed",
tenant_id=tenant_id,
gaps_found=len(gaps),
gaps_processed=len(results),
successful=successful
)
return {
"gaps_found": len(gaps),
"gaps_processed": len(results),
"validations_completed": successful,
"validations_failed": len(results) - successful,
"results": results
}
except Exception as e:
logger.error(
"Auto backfill failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Auto backfill failed: {str(e)}")
async def register_sales_data_update(
self,
tenant_id: uuid.UUID,
start_date: date,
end_date: date,
records_affected: int,
update_source: str = "import",
import_job_id: Optional[str] = None,
auto_trigger_validation: bool = True
) -> Dict[str, Any]:
"""
Register a sales data update and optionally trigger validation
Args:
tenant_id: Tenant identifier
start_date: Start date of updated data
end_date: End date of updated data
records_affected: Number of sales records affected
update_source: Source of update (import, manual, pos_sync)
import_job_id: Optional import job ID
auto_trigger_validation: Whether to automatically trigger validation
Returns:
Update record and validation result if triggered
"""
try:
# Create sales data update record
update_record = SalesDataUpdate(
tenant_id=tenant_id,
update_date_start=start_date,
update_date_end=end_date,
records_affected=records_affected,
update_source=update_source,
import_job_id=import_job_id,
requires_validation=auto_trigger_validation,
validation_status="pending" if auto_trigger_validation else "not_required"
)
self.db.add(update_record)
await self.db.flush()
logger.info(
"Registered sales data update",
tenant_id=tenant_id,
update_id=update_record.id,
date_range=f"{start_date} to {end_date}",
records_affected=records_affected
)
result = {
"update_id": str(update_record.id),
"update_record": update_record.to_dict(),
"validation_triggered": False
}
# Trigger validation if requested
if auto_trigger_validation:
try:
validation_result = await self.backfill_validation(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
triggered_by="sales_data_update",
sales_data_update_id=update_record.id
)
result["validation_triggered"] = True
result["validation_result"] = validation_result
logger.info(
"Validation triggered for sales data update",
update_id=update_record.id,
validation_run_id=validation_result.get("validation_run_id")
)
except Exception as e:
logger.error(
"Failed to trigger validation for sales data update",
update_id=update_record.id,
error=str(e)
)
update_record.validation_status = "failed"
update_record.validation_error = str(e)[:500]
await self.db.commit()
return result
except Exception as e:
logger.error(
"Failed to register sales data update",
tenant_id=tenant_id,
error=str(e)
)
await self.db.rollback()
raise DatabaseError(f"Failed to register sales data update: {str(e)}")
async def _update_sales_data_record(
self,
sales_data_update_id: uuid.UUID,
validation_run_id: Optional[uuid.UUID],
status: str,
error_message: Optional[str] = None
):
"""Update sales data update record with validation results"""
try:
query = select(SalesDataUpdate).where(SalesDataUpdate.id == sales_data_update_id)
result = await self.db.execute(query)
update_record = result.scalar_one_or_none()
if update_record:
update_record.validation_status = status
update_record.validation_run_id = validation_run_id
update_record.validated_at = datetime.now(timezone.utc)
if error_message:
update_record.validation_error = error_message[:500]
await self.db.commit()
except Exception as e:
logger.error(
"Failed to update sales data record",
sales_data_update_id=sales_data_update_id,
error=str(e)
)
async def get_pending_validations(
self,
tenant_id: uuid.UUID,
limit: int = 50
) -> List[SalesDataUpdate]:
"""Get pending sales data updates that need validation"""
try:
query = (
select(SalesDataUpdate)
.where(
and_(
SalesDataUpdate.tenant_id == tenant_id,
SalesDataUpdate.validation_status == "pending",
SalesDataUpdate.requires_validation.is_(True)
)
)
.order_by(SalesDataUpdate.created_at)
.limit(limit)
)
result = await self.db.execute(query)
return result.scalars().all()
except Exception as e:
logger.error(
"Failed to get pending validations",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get pending validations: {str(e)}")


@@ -0,0 +1,435 @@
# ================================================================
# services/forecasting/app/services/performance_monitoring_service.py
# ================================================================
"""
Performance Monitoring Service
Monitors forecast accuracy over time and triggers actions when
performance degrades below acceptable thresholds.
"""
from typing import Dict, Any, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, desc
from datetime import datetime, timedelta, timezone
import structlog
import uuid
from app.models.validation_run import ValidationRun
from app.models.predictions import ModelPerformanceMetric
from app.models.forecasts import Forecast
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class PerformanceMonitoringService:
"""Service for monitoring forecast performance and triggering improvements"""
# Configurable thresholds
MAPE_WARNING_THRESHOLD = 20.0 # Warning if MAPE > 20%
MAPE_CRITICAL_THRESHOLD = 30.0 # Critical if MAPE > 30%
MAPE_TREND_THRESHOLD = 5.0 # Alert if MAPE increases by > 5% over period
MIN_SAMPLES_FOR_ALERT = 5 # Minimum validations before alerting
TREND_LOOKBACK_DAYS = 30 # Days to analyze for trends
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def get_accuracy_summary(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""
Get accuracy summary for recent period
Args:
tenant_id: Tenant identifier
days: Number of days to analyze
Returns:
Summary with overall metrics and trends
"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
# Get recent validation runs
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.started_at >= start_date,
ValidationRun.forecasts_with_actuals > 0 # Only runs with actual data
)
)
.order_by(desc(ValidationRun.started_at))
)
result = await self.db.execute(query)
runs = result.scalars().all()
if not runs:
return {
"status": "no_data",
"message": f"No validation runs found in last {days} days",
"period_days": days
}
# Calculate summary statistics
total_forecasts = sum(r.total_forecasts_evaluated for r in runs)
total_with_actuals = sum(r.forecasts_with_actuals for r in runs)
mape_values = [r.overall_mape for r in runs if r.overall_mape is not None]
mae_values = [r.overall_mae for r in runs if r.overall_mae is not None]
rmse_values = [r.overall_rmse for r in runs if r.overall_rmse is not None]
avg_mape = sum(mape_values) / len(mape_values) if mape_values else None
avg_mae = sum(mae_values) / len(mae_values) if mae_values else None
avg_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else None
# Determine health status
health_status = "healthy"
if avg_mape is not None and avg_mape > self.MAPE_CRITICAL_THRESHOLD:
health_status = "critical"
elif avg_mape is not None and avg_mape > self.MAPE_WARNING_THRESHOLD:
health_status = "warning"
return {
"status": "ok",
"period_days": days,
"validation_runs": len(runs),
"total_forecasts_evaluated": total_forecasts,
"total_forecasts_with_actuals": total_with_actuals,
"coverage_percentage": round(
(total_with_actuals / total_forecasts * 100) if total_forecasts > 0 else 0, 2
),
"average_metrics": {
"mape": round(avg_mape, 2) if avg_mape else None,
"mae": round(avg_mae, 2) if avg_mae else None,
"rmse": round(avg_rmse, 2) if avg_rmse else None,
"accuracy_percentage": round(100 - avg_mape, 2) if avg_mape else None
},
"health_status": health_status,
"thresholds": {
"warning": self.MAPE_WARNING_THRESHOLD,
"critical": self.MAPE_CRITICAL_THRESHOLD
}
}
except Exception as e:
logger.error(
"Failed to get accuracy summary",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get accuracy summary: {str(e)}")
async def detect_performance_degradation(
self,
tenant_id: uuid.UUID,
lookback_days: int = 30
) -> Dict[str, Any]:
"""
Detect if forecast performance is degrading over time
Args:
tenant_id: Tenant identifier
lookback_days: Days to analyze for trends
Returns:
Degradation analysis with recommendations
"""
try:
logger.info(
"Detecting performance degradation",
tenant_id=tenant_id,
lookback_days=lookback_days
)
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Get validation runs ordered by time
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.started_at >= start_date,
ValidationRun.forecasts_with_actuals > 0
)
)
.order_by(ValidationRun.started_at)
)
result = await self.db.execute(query)
runs = list(result.scalars().all())
if len(runs) < self.MIN_SAMPLES_FOR_ALERT:
return {
"status": "insufficient_data",
"message": f"Need at least {self.MIN_SAMPLES_FOR_ALERT} validation runs",
"runs_found": len(runs)
}
# Split into first half and second half
midpoint = len(runs) // 2
first_half = runs[:midpoint]
second_half = runs[midpoint:]
# Calculate average MAPE for each half (skip runs without a MAPE; a 0.0 MAPE is valid)
first_half_mapes = [r.overall_mape for r in first_half if r.overall_mape is not None]
second_half_mapes = [r.overall_mape for r in second_half if r.overall_mape is not None]
if not first_half_mapes or not second_half_mapes:
return {
"status": "insufficient_data",
"message": "Not enough validation runs with MAPE values to compare periods",
"runs_found": len(runs)
}
first_half_mape = sum(first_half_mapes) / len(first_half_mapes)
second_half_mape = sum(second_half_mapes) / len(second_half_mapes)
mape_change = second_half_mape - first_half_mape
mape_change_percentage = (mape_change / first_half_mape * 100) if first_half_mape > 0 else 0
# Determine if degradation is significant
is_degrading = mape_change > self.MAPE_TREND_THRESHOLD
severity = "none"
if is_degrading:
if mape_change > self.MAPE_TREND_THRESHOLD * 2:
severity = "high"
elif mape_change > self.MAPE_TREND_THRESHOLD:
severity = "medium"
# Get products with worst performance
poor_products = await self._identify_poor_performers(tenant_id, lookback_days)
result = {
"status": "analyzed",
"period_days": lookback_days,
"samples_analyzed": len(runs),
"is_degrading": is_degrading,
"severity": severity,
"metrics": {
"first_period_mape": round(first_half_mape, 2),
"second_period_mape": round(second_half_mape, 2),
"mape_change": round(mape_change, 2),
"mape_change_percentage": round(mape_change_percentage, 2)
},
"poor_performers": poor_products,
"recommendations": []
}
# Add recommendations
if is_degrading:
result["recommendations"].append({
"action": "retrain_models",
"priority": "high" if severity == "high" else "medium",
"reason": f"MAPE increased by {abs(mape_change):.1f}% over {lookback_days} days"
})
if poor_products:
result["recommendations"].append({
"action": "retrain_poor_performers",
"priority": "high",
"reason": f"{len(poor_products)} products with MAPE > {self.MAPE_CRITICAL_THRESHOLD}%",
"products": poor_products[:10] # Top 10 worst
})
logger.info(
"Performance degradation analysis complete",
tenant_id=tenant_id,
is_degrading=is_degrading,
severity=severity,
poor_performers=len(poor_products)
)
return result
except Exception as e:
logger.error(
"Failed to detect performance degradation",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to detect degradation: {str(e)}")
async def _identify_poor_performers(
self,
tenant_id: uuid.UUID,
lookback_days: int
) -> List[Dict[str, Any]]:
"""Identify products/locations with poor accuracy"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Get recent performance metrics grouped by product
query = select(
ModelPerformanceMetric.inventory_product_id,
func.avg(ModelPerformanceMetric.mape).label('avg_mape'),
func.avg(ModelPerformanceMetric.mae).label('avg_mae'),
func.count(ModelPerformanceMetric.id).label('sample_count')
).where(
and_(
ModelPerformanceMetric.tenant_id == tenant_id,
ModelPerformanceMetric.created_at >= start_date,
ModelPerformanceMetric.mape.isnot(None)
)
).group_by(
ModelPerformanceMetric.inventory_product_id
).having(
func.avg(ModelPerformanceMetric.mape) > self.MAPE_CRITICAL_THRESHOLD
).order_by(
desc(func.avg(ModelPerformanceMetric.mape))
).limit(20)
result = await self.db.execute(query)
poor_performers = []
for row in result.fetchall():
poor_performers.append({
"inventory_product_id": str(row.inventory_product_id),
"avg_mape": round(row.avg_mape, 2),
"avg_mae": round(row.avg_mae, 2),
"sample_count": row.sample_count,
"requires_retraining": True
})
return poor_performers
except Exception as e:
logger.error("Failed to identify poor performers", error=str(e))
return []
async def check_model_age(
self,
tenant_id: uuid.UUID,
max_age_days: int = 30
) -> Dict[str, Any]:
"""
Check if models are outdated and need retraining
Args:
tenant_id: Tenant identifier
max_age_days: Maximum acceptable model age
Returns:
Analysis of model ages
"""
try:
# Get distinct models used in recent forecasts
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
query = select(
Forecast.model_id,
Forecast.model_version,
Forecast.inventory_product_id,
func.max(Forecast.created_at).label('last_used'),
func.count(Forecast.id).label('forecast_count')
).where(
and_(
Forecast.tenant_id == tenant_id,
Forecast.created_at >= cutoff_date
)
).group_by(
Forecast.model_id,
Forecast.model_version,
Forecast.inventory_product_id
)
result = await self.db.execute(query)
models_info = []
outdated_count = 0
for row in result.fetchall():
# Model training dates live in the training service; until that lookup is
# wired in, outdated_count stays 0 and this method only inventories models in use
models_info.append({
"model_id": row.model_id,
"model_version": row.model_version,
"inventory_product_id": str(row.inventory_product_id),
"last_used": row.last_used.isoformat(),
"forecast_count": row.forecast_count
})
return {
"status": "analyzed",
"models_in_use": len(models_info),
"outdated_models": outdated_count,
"max_age_days": max_age_days,
"models": models_info[:20] # Top 20
}
except Exception as e:
logger.error("Failed to check model age", error=str(e))
raise DatabaseError(f"Failed to check model age: {str(e)}")
async def generate_performance_report(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""
Generate comprehensive performance report
Args:
tenant_id: Tenant identifier
days: Analysis period
Returns:
Complete performance report with recommendations
"""
try:
logger.info(
"Generating performance report",
tenant_id=tenant_id,
days=days
)
# Get all analyses
summary = await self.get_accuracy_summary(tenant_id, days)
degradation = await self.detect_performance_degradation(tenant_id, days)
model_age = await self.check_model_age(tenant_id)
# Compile recommendations
all_recommendations = []
if summary.get("health_status") == "critical":
all_recommendations.append({
"priority": "critical",
"action": "immediate_review",
"reason": f"Overall MAPE is {summary['average_metrics']['mape']}%",
"details": "Forecast accuracy is critically low"
})
all_recommendations.extend(degradation.get("recommendations", []))
report = {
"generated_at": datetime.now(timezone.utc).isoformat(),
"tenant_id": str(tenant_id),
"analysis_period_days": days,
"summary": summary,
"degradation_analysis": degradation,
"model_age_analysis": model_age,
"recommendations": all_recommendations,
"requires_action": len(all_recommendations) > 0
}
logger.info(
"Performance report generated",
tenant_id=tenant_id,
health_status=summary.get("health_status"),
recommendations_count=len(all_recommendations)
)
return report
except Exception as e:
logger.error(
"Failed to generate performance report",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to generate report: {str(e)}")


@@ -0,0 +1,384 @@
# ================================================================
# services/forecasting/app/services/retraining_trigger_service.py
# ================================================================
"""
Retraining Trigger Service
Automatically triggers model retraining based on performance metrics,
accuracy degradation, or data availability.
"""
from typing import Dict, Any, List
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
import uuid
from app.services.performance_monitoring_service import PerformanceMonitoringService
from shared.clients.training_client import TrainingServiceClient
from shared.config.base import BaseServiceSettings
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class RetrainingTriggerService:
"""Service for triggering automatic model retraining"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
self.performance_service = PerformanceMonitoringService(db_session)
# Initialize training client
config = BaseServiceSettings()
self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")
async def evaluate_and_trigger_retraining(
self,
tenant_id: uuid.UUID,
auto_trigger: bool = True
) -> Dict[str, Any]:
"""
Evaluate performance and trigger retraining if needed
Args:
tenant_id: Tenant identifier
auto_trigger: Whether to automatically trigger retraining
Returns:
Evaluation results and retraining actions taken
"""
try:
logger.info(
"Evaluating retraining needs",
tenant_id=tenant_id,
auto_trigger=auto_trigger
)
# Generate performance report
report = await self.performance_service.generate_performance_report(
tenant_id=tenant_id,
days=30
)
if not report.get("requires_action"):
logger.info(
"No retraining required",
tenant_id=tenant_id,
health_status=report["summary"].get("health_status")
)
return {
"status": "no_action_needed",
"tenant_id": str(tenant_id),
"health_status": report["summary"].get("health_status"),
"report": report
}
# Extract products that need retraining
products_to_retrain = []
recommendations = report.get("recommendations", [])
for rec in recommendations:
if rec.get("action") == "retrain_poor_performers":
products_to_retrain.extend(rec.get("products", []))
if not products_to_retrain and auto_trigger:
# Degradation without product-level flags is logged for manual review;
# full-catalog retraining is not triggered automatically here
degradation = report.get("degradation_analysis", {})
if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
logger.info(
"General degradation detected, flagged for manual review",
tenant_id=tenant_id,
severity=degradation.get("severity")
)
retraining_results = []
if auto_trigger and products_to_retrain:
# Trigger retraining for poor performers
for product in products_to_retrain:
try:
result = await self._trigger_product_retraining(
tenant_id=tenant_id,
inventory_product_id=uuid.UUID(product["inventory_product_id"]),
reason=f"MAPE {product['avg_mape']}% exceeds threshold",
priority="high"
)
retraining_results.append(result)
except Exception as e:
logger.error(
"Failed to trigger retraining for product",
product_id=product["inventory_product_id"],
error=str(e)
)
retraining_results.append({
"product_id": product["inventory_product_id"],
"status": "failed",
"error": str(e)
})
logger.info(
"Retraining evaluation complete",
tenant_id=tenant_id,
products_evaluated=len(products_to_retrain),
retraining_triggered=len(retraining_results)
)
return {
"status": "evaluated",
"tenant_id": str(tenant_id),
"requires_action": report.get("requires_action"),
"products_needing_retraining": len(products_to_retrain),
"retraining_triggered": len(retraining_results),
"auto_trigger_enabled": auto_trigger,
"retraining_results": retraining_results,
"performance_report": report
}
except Exception as e:
logger.error(
"Failed to evaluate and trigger retraining",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")
async def _trigger_product_retraining(
self,
tenant_id: uuid.UUID,
inventory_product_id: uuid.UUID,
reason: str,
priority: str = "normal"
) -> Dict[str, Any]:
"""
Trigger retraining for a specific product
Args:
tenant_id: Tenant identifier
inventory_product_id: Product to retrain
reason: Reason for retraining
priority: Priority level (low, normal, high)
Returns:
Retraining trigger result
"""
try:
logger.info(
"Triggering product retraining",
tenant_id=tenant_id,
product_id=inventory_product_id,
reason=reason,
priority=priority
)
# Call training service to trigger retraining
result = await self.training_client.trigger_retrain(
tenant_id=str(tenant_id),
inventory_product_id=str(inventory_product_id),
reason=reason,
priority=priority
)
if result:
logger.info(
"Retraining triggered successfully",
tenant_id=tenant_id,
product_id=inventory_product_id,
training_job_id=result.get("training_job_id")
)
return {
"status": "triggered",
"product_id": str(inventory_product_id),
"training_job_id": result.get("training_job_id"),
"reason": reason,
"priority": priority,
"triggered_at": datetime.now(timezone.utc).isoformat()
}
else:
logger.warning(
"Retraining trigger returned no result",
tenant_id=tenant_id,
product_id=inventory_product_id
)
return {
"status": "no_response",
"product_id": str(inventory_product_id),
"reason": reason
}
except Exception as e:
logger.error(
"Failed to trigger product retraining",
tenant_id=tenant_id,
product_id=inventory_product_id,
error=str(e)
)
return {
"status": "failed",
"product_id": str(inventory_product_id),
"error": str(e),
"reason": reason
}
async def trigger_bulk_retraining(
self,
tenant_id: uuid.UUID,
product_ids: List[uuid.UUID],
reason: str = "Bulk retraining requested"
) -> Dict[str, Any]:
"""
Trigger retraining for multiple products
Args:
tenant_id: Tenant identifier
product_ids: List of products to retrain
reason: Reason for bulk retraining
Returns:
Bulk retraining results
"""
try:
logger.info(
"Triggering bulk retraining",
tenant_id=tenant_id,
product_count=len(product_ids)
)
results = []
for product_id in product_ids:
result = await self._trigger_product_retraining(
tenant_id=tenant_id,
inventory_product_id=product_id,
reason=reason,
priority="normal"
)
results.append(result)
successful = sum(1 for r in results if r["status"] == "triggered")
logger.info(
"Bulk retraining completed",
tenant_id=tenant_id,
total=len(product_ids),
successful=successful,
failed=len(product_ids) - successful
)
return {
"status": "completed",
"tenant_id": str(tenant_id),
"total_products": len(product_ids),
"successful": successful,
"failed": len(product_ids) - successful,
"results": results
}
except Exception as e:
logger.error(
"Bulk retraining failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Bulk retraining failed: {str(e)}")
async def check_and_trigger_scheduled_retraining(
self,
tenant_id: uuid.UUID,
max_model_age_days: int = 30
) -> Dict[str, Any]:
"""
Check model ages and trigger retraining for outdated models
Args:
tenant_id: Tenant identifier
max_model_age_days: Maximum acceptable model age
Returns:
Scheduled retraining results
"""
try:
logger.info(
"Checking for scheduled retraining needs",
tenant_id=tenant_id,
max_model_age_days=max_model_age_days
)
# Get model age analysis
model_age_analysis = await self.performance_service.check_model_age(
tenant_id=tenant_id,
max_age_days=max_model_age_days
)
outdated_count = model_age_analysis.get("outdated_models", 0)
if outdated_count == 0:
logger.info(
"No outdated models found",
tenant_id=tenant_id
)
return {
"status": "no_action_needed",
"tenant_id": str(tenant_id),
"outdated_models": 0
}
# TODO: Trigger retraining for outdated models
# Would need to get list of outdated products from training service
return {
"status": "analyzed",
"tenant_id": str(tenant_id),
"outdated_models": outdated_count,
"message": "Scheduled retraining analysis complete"
}
except Exception as e:
logger.error(
"Scheduled retraining check failed",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")
async def get_retraining_recommendations(
self,
tenant_id: uuid.UUID
) -> Dict[str, Any]:
"""
Get retraining recommendations without triggering
Args:
tenant_id: Tenant identifier
Returns:
Recommendations for manual review
"""
try:
# Evaluate without auto-triggering
result = await self.evaluate_and_trigger_retraining(
tenant_id=tenant_id,
auto_trigger=False
)
# Extract just the recommendations
report = result.get("performance_report", {})
recommendations = report.get("recommendations", [])
return {
"tenant_id": str(tenant_id),
"generated_at": datetime.now(timezone.utc).isoformat(),
"requires_action": result.get("requires_action", False),
"recommendations": recommendations,
"summary": report.get("summary", {}),
"degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
}
except Exception as e:
logger.error(
"Failed to get retraining recommendations",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to get recommendations: {str(e)}")


@@ -0,0 +1,97 @@
# ================================================================
# services/forecasting/app/services/sales_client.py
# ================================================================
"""
Sales Client for Forecasting Service
Wrapper around shared sales client with forecasting-specific methods
"""
from typing import List, Dict, Any, Optional
from datetime import datetime
import structlog
import uuid
from shared.clients.sales_client import SalesServiceClient
from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class SalesClient:
"""Client for fetching sales data from sales service"""
def __init__(self):
"""Initialize sales client"""
# Load configuration
config = BaseServiceSettings()
self.client = SalesServiceClient(config, calling_service_name="forecasting")
async def get_sales_by_date_range(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime,
product_id: Optional[uuid.UUID] = None
) -> List[Dict[str, Any]]:
"""
Get sales data for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of date range
end_date: End of date range
product_id: Optional product filter
Returns:
List of sales records
"""
try:
# Convert datetime to ISO format strings
start_date_str = start_date.isoformat() if start_date else None
end_date_str = end_date.isoformat() if end_date else None
product_id_str = str(product_id) if product_id else None
# Use the paginated method to get all sales data
sales_data = await self.client.get_all_sales_data(
tenant_id=str(tenant_id),
start_date=start_date_str,
end_date=end_date_str,
product_id=product_id_str,
aggregation="none", # Get raw data without aggregation
page_size=1000,
max_pages=100
)
if not sales_data:
logger.info(
"No sales data found for date range",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
return []
logger.info(
"Retrieved sales data",
tenant_id=tenant_id,
records_count=len(sales_data),
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
return sales_data
except Exception as e:
logger.error(
"Failed to fetch sales data",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
# Return empty list instead of raising to allow validation to continue
return []
async def close(self):
"""Close the client connection"""
if hasattr(self.client, 'close'):
await self.client.close()
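
A short sketch of consuming this client. The record fields (`inventory_product_id`, `date`, `quantity_sold`) match how validation_service.py reads the rows; the `daily_totals` helper itself is illustrative:

from app.services.sales_client import SalesClient

async def daily_totals(tenant_id, start_date, end_date) -> dict:
    client = SalesClient()
    try:
        rows = await client.get_sales_by_date_range(tenant_id, start_date, end_date)
        totals: dict = {}
        for row in rows:
            # Sum quantities per (product, day), mirroring the validation lookup
            key = (str(row["inventory_product_id"]), row["date"])
            totals[key] = totals.get(key, 0) + row["quantity_sold"]
        return totals
    finally:
        await client.close()  # release the underlying HTTP client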


@@ -0,0 +1,586 @@
# ================================================================
# services/forecasting/app/services/validation_service.py
# ================================================================
"""
Forecast Validation Service
Compares historical forecasts with actual sales data to:
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
2. Store performance metrics in the database
3. Track validation runs for audit purposes
4. Enable continuous model improvement
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date
from datetime import datetime, timedelta, timezone
import structlog
import math
import uuid
from app.models.forecasts import Forecast
from app.models.predictions import ModelPerformanceMetric
from app.models.validation_run import ValidationRun
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class ValidationService:
"""Service for validating forecasts against actual sales data"""
def __init__(self, db_session: AsyncSession):
self.db = db_session
async def validate_date_range(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime,
orchestration_run_id: Optional[uuid.UUID] = None,
triggered_by: str = "manual"
) -> Dict[str, Any]:
"""
Validate forecasts against actual sales for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of validation period
end_date: End of validation period
orchestration_run_id: Optional link to orchestration run
triggered_by: How this validation was triggered (manual, orchestrator, scheduled)
Returns:
Dictionary with validation results and metrics
"""
validation_run = None
try:
# Create validation run record
validation_run = ValidationRun(
tenant_id=tenant_id,
orchestration_run_id=orchestration_run_id,
validation_start_date=start_date,
validation_end_date=end_date,
status="running",
triggered_by=triggered_by,
execution_mode="batch"
)
self.db.add(validation_run)
await self.db.flush()
logger.info(
"Starting forecast validation",
validation_run_id=validation_run.id,
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
# Fetch forecasts with matching sales data
forecasts_with_sales = await self._fetch_forecasts_with_sales(
tenant_id, start_date, end_date
)
if not forecasts_with_sales:
logger.warning(
"No forecasts with matching sales data found",
tenant_id=tenant_id,
start_date=start_date.isoformat(),
end_date=end_date.isoformat()
)
validation_run.status = "completed"
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
return {
"validation_run_id": str(validation_run.id),
"status": "completed",
"message": "No forecasts with matching sales data found",
"forecasts_evaluated": 0,
"metrics_created": 0
}
# Calculate metrics and create performance records
metrics_results = await self._calculate_and_store_metrics(
forecasts_with_sales, validation_run.id
)
# Update validation run with results
validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
validation_run.overall_mae = metrics_results["overall_mae"]
validation_run.overall_mape = metrics_results["overall_mape"]
validation_run.overall_rmse = metrics_results["overall_rmse"]
validation_run.overall_r2_score = metrics_results["overall_r2_score"]
validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
validation_run.total_predicted_demand = metrics_results["total_predicted"]
validation_run.total_actual_demand = metrics_results["total_actual"]
validation_run.metrics_by_product = metrics_results["metrics_by_product"]
validation_run.metrics_by_location = metrics_results["metrics_by_location"]
validation_run.metrics_records_created = metrics_results["metrics_created"]
validation_run.status = "completed"
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
logger.info(
"Forecast validation completed successfully",
validation_run_id=validation_run.id,
forecasts_evaluated=validation_run.total_forecasts_evaluated,
metrics_created=validation_run.metrics_records_created,
overall_mape=validation_run.overall_mape,
duration_seconds=validation_run.duration_seconds
)
# Extract poor accuracy products (MAPE > 30%)
poor_accuracy_products = []
if validation_run.metrics_by_product:
for product_id, product_metrics in validation_run.metrics_by_product.items():
if product_metrics.get("mape", 0) > 30:
poor_accuracy_products.append({
"product_id": product_id,
"mape": product_metrics.get("mape"),
"mae": product_metrics.get("mae"),
"accuracy_percentage": product_metrics.get("accuracy_percentage")
})
return {
"validation_run_id": str(validation_run.id),
"status": "completed",
"forecasts_evaluated": validation_run.total_forecasts_evaluated,
"forecasts_with_actuals": validation_run.forecasts_with_actuals,
"forecasts_without_actuals": validation_run.forecasts_without_actuals,
"metrics_created": validation_run.metrics_records_created,
"overall_metrics": {
"mae": validation_run.overall_mae,
"mape": validation_run.overall_mape,
"rmse": validation_run.overall_rmse,
"r2_score": validation_run.overall_r2_score,
"accuracy_percentage": validation_run.overall_accuracy_percentage
},
"total_predicted_demand": validation_run.total_predicted_demand,
"total_actual_demand": validation_run.total_actual_demand,
"duration_seconds": validation_run.duration_seconds,
"poor_accuracy_products": poor_accuracy_products,
"metrics_by_product": validation_run.metrics_by_product,
"metrics_by_location": validation_run.metrics_by_location
}
except Exception as e:
logger.error(
"Forecast validation failed",
tenant_id=tenant_id,
error=str(e),
error_type=type(e).__name__
)
if validation_run:
validation_run.status = "failed"
validation_run.error_message = str(e)
validation_run.error_details = {"error_type": type(e).__name__}
validation_run.completed_at = datetime.now(timezone.utc)
validation_run.duration_seconds = (
validation_run.completed_at - validation_run.started_at
).total_seconds()
await self.db.commit()
raise DatabaseError(f"Forecast validation failed: {str(e)}")
async def validate_yesterday(
self,
tenant_id: uuid.UUID,
orchestration_run_id: Optional[uuid.UUID] = None,
triggered_by: str = "orchestrator"
) -> Dict[str, Any]:
"""
Convenience method to validate yesterday's forecasts
Args:
tenant_id: Tenant identifier
orchestration_run_id: Optional link to orchestration run
triggered_by: How this validation was triggered
Returns:
Dictionary with validation results
"""
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
return await self.validate_date_range(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
orchestration_run_id=orchestration_run_id,
triggered_by=triggered_by
)
async def _fetch_forecasts_with_sales(
self,
tenant_id: uuid.UUID,
start_date: datetime,
end_date: datetime
) -> List[Dict[str, Any]]:
"""
Fetch forecasts with their corresponding actual sales data
Returns list of dictionaries containing forecast and sales data
"""
try:
# Import here to avoid circular dependency
from app.services.sales_client import SalesClient
# Query to get all forecasts in the date range
query = select(Forecast).where(
and_(
Forecast.tenant_id == tenant_id,
func.cast(Forecast.forecast_date, Date) >= start_date.date(),
func.cast(Forecast.forecast_date, Date) <= end_date.date()
)
).order_by(Forecast.forecast_date, Forecast.inventory_product_id)
result = await self.db.execute(query)
forecasts = result.scalars().all()
if not forecasts:
return []
# Fetch actual sales data from sales service
sales_client = SalesClient()
sales_data = await sales_client.get_sales_by_date_range(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
# Create lookup dict: (product_id, date) -> sales quantity
sales_lookup = {}
for sale in sales_data:
sale_date = sale['date']
if isinstance(sale_date, str):
sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))
key = (str(sale['inventory_product_id']), sale_date.date())
# Sum quantities if multiple sales records for same product/date
if key in sales_lookup:
sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
else:
sales_lookup[key] = sale
# Match forecasts with sales data
forecasts_with_sales = []
for forecast in forecasts:
forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
key = (str(forecast.inventory_product_id), forecast_date)
sales_record = sales_lookup.get(key)
forecasts_with_sales.append({
"forecast_id": forecast.id,
"tenant_id": forecast.tenant_id,
"inventory_product_id": forecast.inventory_product_id,
"product_name": forecast.product_name,
"location": forecast.location,
"forecast_date": forecast.forecast_date,
"predicted_demand": forecast.predicted_demand,
"confidence_lower": forecast.confidence_lower,
"confidence_upper": forecast.confidence_upper,
"model_id": forecast.model_id,
"model_version": forecast.model_version,
"algorithm": forecast.algorithm,
"created_at": forecast.created_at,
"actual_sales": sales_record['quantity_sold'] if sales_record else None,
"has_actual_data": sales_record is not None
})
return forecasts_with_sales
except Exception as e:
logger.error(
"Failed to fetch forecasts with sales data",
tenant_id=tenant_id,
error=str(e)
)
raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}")
async def _calculate_and_store_metrics(
self,
forecasts_with_sales: List[Dict[str, Any]],
validation_run_id: uuid.UUID
) -> Dict[str, Any]:
"""
Calculate accuracy metrics and store them in the database
Returns summary of metrics calculated
"""
# Separate forecasts with and without actual data
forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
if not forecasts_with_actuals:
return {
"total_evaluated": len(forecasts_with_sales),
"with_actuals": 0,
"without_actuals": len(forecasts_without_actuals),
"metrics_created": 0,
"overall_mae": None,
"overall_mape": None,
"overall_rmse": None,
"overall_r2_score": None,
"overall_accuracy_percentage": None,
"total_predicted": 0.0,
"total_actual": 0.0,
"metrics_by_product": {},
"metrics_by_location": {}
}
# Calculate individual metrics and prepare for bulk insert
performance_metrics = []
errors = []
for forecast in forecasts_with_actuals:
predicted = forecast["predicted_demand"]
actual = forecast["actual_sales"]
# Calculate error
error = abs(predicted - actual)
errors.append(error)
# Percentage error; zero-actual days score 0.0, which biases MAPE optimistic
percentage_error = (error / actual * 100) if actual > 0 else 0.0
# Per-record metrics: with a single observation, MAE and RMSE both
# reduce to the absolute error and MAPE to the percentage error
mae = error
mape = percentage_error
rmse = error # for a sample of one, RMSE equals the absolute error
# Calculate accuracy percentage (100% - MAPE, capped at 0)
accuracy_percentage = max(0.0, 100.0 - mape)
# Create performance metric record
metric = ModelPerformanceMetric(
model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
tenant_id=forecast["tenant_id"],
inventory_product_id=forecast["inventory_product_id"],
mae=mae,
mape=mape,
rmse=rmse,
accuracy_score=accuracy_percentage / 100.0, # Store as 0-1 scale
evaluation_date=forecast["forecast_date"],
evaluation_period_start=forecast["forecast_date"],
evaluation_period_end=forecast["forecast_date"],
sample_size=1
)
performance_metrics.append(metric)
# Bulk insert all performance metrics
if performance_metrics:
self.db.add_all(performance_metrics)
await self.db.flush()
# Calculate overall metrics
overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
# Calculate metrics by product
metrics_by_product = self._calculate_metrics_by_dimension(
forecasts_with_actuals, "inventory_product_id"
)
# Calculate metrics by location
metrics_by_location = self._calculate_metrics_by_dimension(
forecasts_with_actuals, "location"
)
return {
"total_evaluated": len(forecasts_with_sales),
"with_actuals": len(forecasts_with_actuals),
"without_actuals": len(forecasts_without_actuals),
"metrics_created": len(performance_metrics),
"overall_mae": overall_metrics["mae"],
"overall_mape": overall_metrics["mape"],
"overall_rmse": overall_metrics["rmse"],
"overall_r2_score": overall_metrics["r2_score"],
"overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
"total_predicted": overall_metrics["total_predicted"],
"total_actual": overall_metrics["total_actual"],
"metrics_by_product": metrics_by_product,
"metrics_by_location": metrics_by_location
}
def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, float]:
"""Calculate aggregated metrics across all forecasts"""
if not forecasts:
return {
"mae": None, "mape": None, "rmse": None,
"r2_score": None, "accuracy_percentage": None,
"total_predicted": 0.0, "total_actual": 0.0
}
predicted_values = [f["predicted_demand"] for f in forecasts]
actual_values = [f["actual_sales"] for f in forecasts]
# MAE: Mean Absolute Error
mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)
# MAPE: Mean Absolute Percentage Error (handle division by zero)
mape_values = [
abs(p - a) / a * 100 if a > 0 else 0.0
for p, a in zip(predicted_values, actual_values)
]
mape = sum(mape_values) / len(mape_values)
# RMSE: Root Mean Square Error
squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
# R² Score (coefficient of determination)
mean_actual = sum(actual_values) / len(actual_values)
ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0
# Accuracy percentage (100% - MAPE)
accuracy_percentage = max(0.0, 100.0 - mape)
return {
"mae": round(mae, 2),
"mape": round(mape, 2),
"rmse": round(rmse, 2),
"r2_score": round(r2_score, 4),
"accuracy_percentage": round(accuracy_percentage, 2),
"total_predicted": round(sum(predicted_values), 2),
"total_actual": round(sum(actual_values), 2)
}
def _calculate_metrics_by_dimension(
self,
forecasts: List[Dict[str, Any]],
dimension_key: str
) -> Dict[str, Dict[str, float]]:
"""Calculate metrics grouped by a dimension (product_id or location)"""
dimension_groups = {}
# Group forecasts by dimension
for forecast in forecasts:
key = str(forecast[dimension_key])
if key not in dimension_groups:
dimension_groups[key] = []
dimension_groups[key].append(forecast)
# Calculate metrics for each group
metrics_by_dimension = {}
for key, group_forecasts in dimension_groups.items():
metrics = self._calculate_overall_metrics(group_forecasts)
metrics_by_dimension[key] = {
"count": len(group_forecasts),
"mae": metrics["mae"],
"mape": metrics["mape"],
"rmse": metrics["rmse"],
"accuracy_percentage": metrics["accuracy_percentage"],
"total_predicted": metrics["total_predicted"],
"total_actual": metrics["total_actual"]
}
return metrics_by_dimension
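# Example output, keyed by the dimension value (here a product id):
# {"7f3a...-uuid": {"count": 3, "mae": 13.33, "mape": 13.4, "rmse": 14.14,
#                   "accuracy_percentage": 86.6, "total_predicted": 300.0, "total_actual": 300.0}}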
async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
"""Get a validation run by ID"""
try:
query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
result = await self.db.execute(query)
return result.scalar_one_or_none()
except Exception as e:
logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
raise DatabaseError(f"Failed to get validation run: {str(e)}")
async def get_validation_runs_by_tenant(
self,
tenant_id: uuid.UUID,
limit: int = 50,
skip: int = 0
) -> List[ValidationRun]:
"""Get validation runs for a tenant"""
try:
query = (
select(ValidationRun)
.where(ValidationRun.tenant_id == tenant_id)
.order_by(ValidationRun.created_at.desc())
.limit(limit)
.offset(skip)
)
result = await self.db.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
raise DatabaseError(f"Failed to get validation runs: {str(e)}")
async def get_accuracy_trends(
self,
tenant_id: uuid.UUID,
days: int = 30
) -> Dict[str, Any]:
"""Get accuracy trends over time"""
try:
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = (
select(ValidationRun)
.where(
and_(
ValidationRun.tenant_id == tenant_id,
ValidationRun.status == "completed",
ValidationRun.created_at >= start_date
)
)
.order_by(ValidationRun.created_at)
)
result = await self.db.execute(query)
runs = result.scalars().all()
if not runs:
return {
"period_days": days,
"total_runs": 0,
"trends": []
}
trends = [
{
"date": run.validation_start_date.isoformat(),
"mae": run.overall_mae,
"mape": run.overall_mape,
"rmse": run.overall_rmse,
"accuracy_percentage": run.overall_accuracy_percentage,
"forecasts_evaluated": run.total_forecasts_evaluated,
"forecasts_with_actuals": run.forecasts_with_actuals
}
for run in runs
]
# Calculate averages (filter on "is not None" so a legitimate 0.0 still counts)
mape_values = [r.overall_mape for r in runs if r.overall_mape is not None]
accuracy_values = [r.overall_accuracy_percentage for r in runs if r.overall_accuracy_percentage is not None]
avg_mape = sum(mape_values) / len(mape_values) if mape_values else None
avg_accuracy = sum(accuracy_values) / len(accuracy_values) if accuracy_values else None
return {
"period_days": days,
"total_runs": len(runs),
"average_mape": round(avg_mape, 2) if avg_mape else None,
"average_accuracy": round(avg_accuracy, 2) if avg_accuracy else None,
"trends": trends
}
except Exception as e:
logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
raise DatabaseError(f"Failed to get accuracy trends: {str(e)}")