# ================================================================
# services/forecasting/app/services/validation_service.py
# ================================================================
"""
Forecast Validation Service

Compares historical forecasts with actual sales data to:
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
2. Store performance metrics in the database
3. Track validation runs for audit purposes
4. Enable continuous model improvement
"""

from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date
from datetime import datetime, timedelta, timezone
import structlog
import math
import uuid

from app.models.forecasts import Forecast
from app.models.predictions import ModelPerformanceMetric
from app.models.validation_run import ValidationRun
from shared.database.exceptions import DatabaseError

logger = structlog.get_logger()
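
# Typical usage (illustrative sketch; `async_session_factory` and `tenant_id`
# here are assumptions, not part of this module):
#
#     async with async_session_factory() as session:
#         service = ValidationService(session)
#         results = await service.validate_yesterday(tenant_id=tenant_id)
#         mape = results.get("overall_metrics", {}).get("mape")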


class ValidationService:
    """Service for validating forecasts against actual sales data"""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session
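        # Note: this service commits the caller's session (see
        # validate_date_range), so avoid sharing a session that holds
        # unrelated uncommitted work across a validation call.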

    async def validate_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "manual"
    ) -> Dict[str, Any]:
        """
        Validate forecasts against actual sales for a date range

        Args:
            tenant_id: Tenant identifier
            start_date: Start of validation period
            end_date: End of validation period
            orchestration_run_id: Optional link to orchestration run
            triggered_by: How this validation was triggered (manual, orchestrator, scheduled)

        Returns:
            Dictionary with validation results and metrics
        """
        validation_run = None

        try:
            # Create validation run record
            validation_run = ValidationRun(
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                validation_start_date=start_date,
                validation_end_date=end_date,
                status="running",
                triggered_by=triggered_by,
                execution_mode="batch"
            )
            self.db.add(validation_run)
            await self.db.flush()

            logger.info(
                "Starting forecast validation",
                validation_run_id=validation_run.id,
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )

            # Fetch forecasts with matching sales data
            forecasts_with_sales = await self._fetch_forecasts_with_sales(
                tenant_id, start_date, end_date
            )

            if not forecasts_with_sales:
                logger.warning(
                    "No forecasts with matching sales data found",
                    tenant_id=tenant_id,
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                validation_run.status = "completed"
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

                return {
                    "validation_run_id": str(validation_run.id),
                    "status": "completed",
                    "message": "No forecasts with matching sales data found",
                    "forecasts_evaluated": 0,
                    "metrics_created": 0
                }

            # Calculate metrics and create performance records
            metrics_results = await self._calculate_and_store_metrics(
                forecasts_with_sales, validation_run.id
            )

            # Update validation run with results
            validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
            validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
            validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
            validation_run.overall_mae = metrics_results["overall_mae"]
            validation_run.overall_mape = metrics_results["overall_mape"]
            validation_run.overall_rmse = metrics_results["overall_rmse"]
            validation_run.overall_r2_score = metrics_results["overall_r2_score"]
            validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
            validation_run.total_predicted_demand = metrics_results["total_predicted"]
            validation_run.total_actual_demand = metrics_results["total_actual"]
            validation_run.metrics_by_product = metrics_results["metrics_by_product"]
            validation_run.metrics_by_location = metrics_results["metrics_by_location"]
            validation_run.metrics_records_created = metrics_results["metrics_created"]
            validation_run.status = "completed"
            validation_run.completed_at = datetime.now(timezone.utc)
            validation_run.duration_seconds = (
                validation_run.completed_at - validation_run.started_at
            ).total_seconds()

            await self.db.commit()

            logger.info(
                "Forecast validation completed successfully",
                validation_run_id=validation_run.id,
                forecasts_evaluated=validation_run.total_forecasts_evaluated,
                metrics_created=validation_run.metrics_records_created,
                overall_mape=validation_run.overall_mape,
                duration_seconds=validation_run.duration_seconds
            )

            # Extract poor accuracy products (MAPE > 30%)
            poor_accuracy_products = []
            if validation_run.metrics_by_product:
                for product_id, product_metrics in validation_run.metrics_by_product.items():
                    if product_metrics.get("mape", 0) > 30:
                        poor_accuracy_products.append({
                            "product_id": product_id,
                            "mape": product_metrics.get("mape"),
                            "mae": product_metrics.get("mae"),
                            "accuracy_percentage": product_metrics.get("accuracy_percentage")
                        })

            return {
                "validation_run_id": str(validation_run.id),
                "status": "completed",
                "forecasts_evaluated": validation_run.total_forecasts_evaluated,
                "forecasts_with_actuals": validation_run.forecasts_with_actuals,
                "forecasts_without_actuals": validation_run.forecasts_without_actuals,
                "metrics_created": validation_run.metrics_records_created,
                "overall_metrics": {
                    "mae": validation_run.overall_mae,
                    "mape": validation_run.overall_mape,
                    "rmse": validation_run.overall_rmse,
                    "r2_score": validation_run.overall_r2_score,
                    "accuracy_percentage": validation_run.overall_accuracy_percentage
                },
                "total_predicted_demand": validation_run.total_predicted_demand,
                "total_actual_demand": validation_run.total_actual_demand,
                "duration_seconds": validation_run.duration_seconds,
                "poor_accuracy_products": poor_accuracy_products,
                "metrics_by_product": validation_run.metrics_by_product,
                "metrics_by_location": validation_run.metrics_by_location
            }

        except Exception as e:
            logger.error(
                "Forecast validation failed",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )

            if validation_run is not None:
                validation_run.status = "failed"
                validation_run.error_message = str(e)
                validation_run.error_details = {"error_type": type(e).__name__}
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

            raise DatabaseError(f"Forecast validation failed: {str(e)}") from e

    async def validate_yesterday(
        self,
        tenant_id: uuid.UUID,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "orchestrator"
    ) -> Dict[str, Any]:
        """
        Convenience method to validate yesterday's forecasts

        Args:
            tenant_id: Tenant identifier
            orchestration_run_id: Optional link to orchestration run
            triggered_by: How this validation was triggered

        Returns:
            Dictionary with validation results
        """
        yesterday = datetime.now(timezone.utc) - timedelta(days=1)
        start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
        end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)

        return await self.validate_date_range(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            orchestration_run_id=orchestration_run_id,
            triggered_by=triggered_by
        )

    async def _fetch_forecasts_with_sales(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime
    ) -> List[Dict[str, Any]]:
        """
        Fetch forecasts with their corresponding actual sales data

        Returns list of dictionaries containing forecast and sales data
        """
        try:
            # Import here to avoid circular dependency
            from app.services.sales_client import SalesClient

            # Query to get all forecasts in the date range
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == tenant_id,
                    func.cast(Forecast.forecast_date, Date) >= start_date.date(),
                    func.cast(Forecast.forecast_date, Date) <= end_date.date()
                )
            ).order_by(Forecast.forecast_date, Forecast.inventory_product_id)

            result = await self.db.execute(query)
            forecasts = result.scalars().all()

            if not forecasts:
                return []

            # Fetch actual sales data from sales service
            sales_client = SalesClient()
            sales_data = await sales_client.get_sales_by_date_range(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date
            )

            # Create lookup dict: (product_id, date) -> sales record
            sales_lookup = {}
            for sale in sales_data:
                sale_date = sale['date']
                if isinstance(sale_date, str):
                    sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))

                key = (str(sale['inventory_product_id']), sale_date.date())

                # Sum quantities if multiple sales records for same product/date
                if key in sales_lookup:
                    sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
                else:
                    # Store a copy so aggregation does not mutate the client's data
                    sales_lookup[key] = dict(sale)

            # Match forecasts with sales data
            forecasts_with_sales = []
            for forecast in forecasts:
                forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
                key = (str(forecast.inventory_product_id), forecast_date)

                sales_record = sales_lookup.get(key)

                forecasts_with_sales.append({
                    "forecast_id": forecast.id,
                    "tenant_id": forecast.tenant_id,
                    "inventory_product_id": forecast.inventory_product_id,
                    "product_name": forecast.product_name,
                    "location": forecast.location,
                    "forecast_date": forecast.forecast_date,
                    "predicted_demand": forecast.predicted_demand,
                    "confidence_lower": forecast.confidence_lower,
                    "confidence_upper": forecast.confidence_upper,
                    "model_id": forecast.model_id,
                    "model_version": forecast.model_version,
                    "algorithm": forecast.algorithm,
                    "created_at": forecast.created_at,
                    "actual_sales": sales_record['quantity_sold'] if sales_record else None,
                    "has_actual_data": sales_record is not None
                })

            return forecasts_with_sales

        except Exception as e:
            logger.error(
                "Failed to fetch forecasts with sales data",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}") from e

    async def _calculate_and_store_metrics(
        self,
        forecasts_with_sales: List[Dict[str, Any]],
        validation_run_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Calculate accuracy metrics and store them in the database

        Returns summary of metrics calculated
        """
        # Separate forecasts with and without actual data
        forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
        forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]

        if not forecasts_with_actuals:
            return {
                "total_evaluated": len(forecasts_with_sales),
                "with_actuals": 0,
                "without_actuals": len(forecasts_without_actuals),
                "metrics_created": 0,
                "overall_mae": None,
                "overall_mape": None,
                "overall_rmse": None,
                "overall_r2_score": None,
                "overall_accuracy_percentage": None,
                "total_predicted": 0.0,
                "total_actual": 0.0,
                "metrics_by_product": {},
                "metrics_by_location": {}
            }

        # Calculate individual metrics and prepare for bulk insert
        performance_metrics = []

        for forecast in forecasts_with_actuals:
            predicted = forecast["predicted_demand"]
            actual = forecast["actual_sales"]

            # Calculate absolute error
            error = abs(predicted - actual)

            # Calculate percentage error (avoid division by zero)
            percentage_error = (error / actual * 100) if actual > 0 else 0.0
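
            # Note: when actual == 0 the percentage error is defined as 0.0,
            # which treats zero-sales days as perfectly forecast; a stricter
            # convention (e.g. excluding them from MAPE) would raise the metric.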

            # Calculate individual metrics
            mae = error
            mape = percentage_error
            rmse = error  # For a single observation, RMSE equals the absolute error

            # Calculate accuracy percentage (100% - MAPE, floored at 0)
            accuracy_percentage = max(0.0, 100.0 - mape)

            # Create performance metric record
            metric = ModelPerformanceMetric(
                model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
                tenant_id=forecast["tenant_id"],
                inventory_product_id=forecast["inventory_product_id"],
                mae=mae,
                mape=mape,
                rmse=rmse,
                accuracy_score=accuracy_percentage / 100.0,  # Store as 0-1 scale
                evaluation_date=forecast["forecast_date"],
                evaluation_period_start=forecast["forecast_date"],
                evaluation_period_end=forecast["forecast_date"],
                sample_size=1
            )
            performance_metrics.append(metric)

        # Bulk insert all performance metrics
        if performance_metrics:
            self.db.add_all(performance_metrics)
            await self.db.flush()

        # Calculate overall metrics
        overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)

        # Calculate metrics by product
        metrics_by_product = self._calculate_metrics_by_dimension(
            forecasts_with_actuals, "inventory_product_id"
        )

        # Calculate metrics by location
        metrics_by_location = self._calculate_metrics_by_dimension(
            forecasts_with_actuals, "location"
        )

        return {
            "total_evaluated": len(forecasts_with_sales),
            "with_actuals": len(forecasts_with_actuals),
            "without_actuals": len(forecasts_without_actuals),
            "metrics_created": len(performance_metrics),
            "overall_mae": overall_metrics["mae"],
            "overall_mape": overall_metrics["mape"],
            "overall_rmse": overall_metrics["rmse"],
            "overall_r2_score": overall_metrics["r2_score"],
            "overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
            "total_predicted": overall_metrics["total_predicted"],
            "total_actual": overall_metrics["total_actual"],
            "metrics_by_product": metrics_by_product,
            "metrics_by_location": metrics_by_location
        }

    def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Calculate aggregated metrics across all forecasts"""
        if not forecasts:
            return {
                "mae": None, "mape": None, "rmse": None,
                "r2_score": None, "accuracy_percentage": None,
                "total_predicted": 0.0, "total_actual": 0.0
            }

        predicted_values = [f["predicted_demand"] for f in forecasts]
        actual_values = [f["actual_sales"] for f in forecasts]

        # MAE: Mean Absolute Error
        mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)

        # MAPE: Mean Absolute Percentage Error (handle division by zero)
        mape_values = [
            abs(p - a) / a * 100 if a > 0 else 0.0
            for p, a in zip(predicted_values, actual_values)
        ]
        mape = sum(mape_values) / len(mape_values)

        # RMSE: Root Mean Square Error
        squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
        rmse = math.sqrt(sum(squared_errors) / len(squared_errors))

        # R² Score (coefficient of determination)
        mean_actual = sum(actual_values) / len(actual_values)
        ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
        ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
        r2_score = (1 - ss_residual / ss_total) if ss_total > 0 else 0.0

        # Accuracy percentage (100% - MAPE)
        accuracy_percentage = max(0.0, 100.0 - mape)

        return {
            "mae": round(mae, 2),
            "mape": round(mape, 2),
            "rmse": round(rmse, 2),
            "r2_score": round(r2_score, 4),
            "accuracy_percentage": round(accuracy_percentage, 2),
            "total_predicted": round(sum(predicted_values), 2),
            "total_actual": round(sum(actual_values), 2)
        }

    def _calculate_metrics_by_dimension(
        self,
        forecasts: List[Dict[str, Any]],
        dimension_key: str
    ) -> Dict[str, Dict[str, Any]]:
        """Calculate metrics grouped by a dimension (product_id or location)"""
        dimension_groups = {}

        # Group forecasts by dimension
        for forecast in forecasts:
            key = str(forecast[dimension_key])
            dimension_groups.setdefault(key, []).append(forecast)

        # Calculate metrics for each group
        metrics_by_dimension = {}
        for key, group_forecasts in dimension_groups.items():
            metrics = self._calculate_overall_metrics(group_forecasts)
            metrics_by_dimension[key] = {
                "count": len(group_forecasts),
                "mae": metrics["mae"],
                "mape": metrics["mape"],
                "rmse": metrics["rmse"],
                "accuracy_percentage": metrics["accuracy_percentage"],
                "total_predicted": metrics["total_predicted"],
                "total_actual": metrics["total_actual"]
            }

        return metrics_by_dimension

    async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
        """Get a validation run by ID"""
        try:
            query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
            result = await self.db.execute(query)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
            raise DatabaseError(f"Failed to get validation run: {str(e)}") from e

    async def get_validation_runs_by_tenant(
        self,
        tenant_id: uuid.UUID,
        limit: int = 50,
        skip: int = 0
    ) -> List[ValidationRun]:
        """Get validation runs for a tenant"""
        try:
            query = (
                select(ValidationRun)
                .where(ValidationRun.tenant_id == tenant_id)
                .order_by(ValidationRun.created_at.desc())
                .limit(limit)
                .offset(skip)
            )
            result = await self.db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
            raise DatabaseError(f"Failed to get validation runs: {str(e)}") from e

    async def get_accuracy_trends(
        self,
        tenant_id: uuid.UUID,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get accuracy trends over time"""
        try:
            start_date = datetime.now(timezone.utc) - timedelta(days=days)

            query = (
                select(ValidationRun)
                .where(
                    and_(
                        ValidationRun.tenant_id == tenant_id,
                        ValidationRun.status == "completed",
                        ValidationRun.created_at >= start_date
                    )
                )
                .order_by(ValidationRun.created_at)
            )

            result = await self.db.execute(query)
            runs = result.scalars().all()

            if not runs:
                return {
                    "period_days": days,
                    "total_runs": 0,
                    "trends": []
                }

            trends = [
                {
                    "date": run.validation_start_date.isoformat(),
                    "mae": run.overall_mae,
                    "mape": run.overall_mape,
                    "rmse": run.overall_rmse,
                    "accuracy_percentage": run.overall_accuracy_percentage,
                    "forecasts_evaluated": run.total_forecasts_evaluated,
                    "forecasts_with_actuals": run.forecasts_with_actuals
                }
                for run in runs
            ]

            # Calculate averages (explicit None checks so a 0.0 average is kept)
            valid_runs = [r for r in runs if r.overall_mape is not None]
            avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
            avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None

            return {
                "period_days": days,
                "total_runs": len(runs),
                "average_mape": round(avg_mape, 2) if avg_mape is not None else None,
                "average_accuracy": round(avg_accuracy, 2) if avg_accuracy is not None else None,
                "trends": trends
            }

        except Exception as e:
            logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
            raise DatabaseError(f"Failed to get accuracy trends: {str(e)}") from e