# ================================================================
# services/forecasting/app/services/validation_service.py
# ================================================================
"""
Forecast Validation Service
Compares historical forecasts with actual sales data to:
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
2. Store performance metrics in the database
3. Track validation runs for audit purposes
4. Enable continuous model improvement
"""
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func, Date
from datetime import datetime, timedelta, timezone
import structlog
import math
import uuid
from app.models.forecasts import Forecast
from app.models.predictions import ModelPerformanceMetric
from app.models.validation_run import ValidationRun
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()


class ValidationService:
    """Service for validating forecasts against actual sales data"""

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    async def validate_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "manual"
    ) -> Dict[str, Any]:
"""
Validate forecasts against actual sales for a date range
Args:
tenant_id: Tenant identifier
start_date: Start of validation period
end_date: End of validation period
orchestration_run_id: Optional link to orchestration run
triggered_by: How this validation was triggered (manual, orchestrator, scheduled)
Returns:
Dictionary with validation results and metrics
"""
        validation_run = None
        try:
            # Create validation run record; started_at is set explicitly so the
            # duration calculations below do not depend on a model default
            validation_run = ValidationRun(
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                validation_start_date=start_date,
                validation_end_date=end_date,
                status="running",
                triggered_by=triggered_by,
                execution_mode="batch",
                started_at=datetime.now(timezone.utc)
            )
            self.db.add(validation_run)
            await self.db.flush()
            logger.info(
                "Starting forecast validation",
                validation_run_id=validation_run.id,
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )

            # Fetch forecasts with matching sales data
            forecasts_with_sales = await self._fetch_forecasts_with_sales(
                tenant_id, start_date, end_date
            )
            if not forecasts_with_sales:
                logger.warning(
                    "No forecasts with matching sales data found",
                    tenant_id=tenant_id,
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                validation_run.status = "completed"
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()
                return {
                    "validation_run_id": str(validation_run.id),
                    "status": "completed",
                    "message": "No forecasts with matching sales data found",
                    "forecasts_evaluated": 0,
                    "metrics_created": 0
                }
            # Calculate metrics and create performance records
            metrics_results = await self._calculate_and_store_metrics(
                forecasts_with_sales, validation_run.id
            )

            # Update validation run with results
            validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
            validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
            validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
            validation_run.overall_mae = metrics_results["overall_mae"]
            validation_run.overall_mape = metrics_results["overall_mape"]
            validation_run.overall_rmse = metrics_results["overall_rmse"]
            validation_run.overall_r2_score = metrics_results["overall_r2_score"]
            validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
            validation_run.total_predicted_demand = metrics_results["total_predicted"]
            validation_run.total_actual_demand = metrics_results["total_actual"]
            validation_run.metrics_by_product = metrics_results["metrics_by_product"]
            validation_run.metrics_by_location = metrics_results["metrics_by_location"]
            validation_run.metrics_records_created = metrics_results["metrics_created"]
            validation_run.status = "completed"
            validation_run.completed_at = datetime.now(timezone.utc)
            validation_run.duration_seconds = (
                validation_run.completed_at - validation_run.started_at
            ).total_seconds()
            await self.db.commit()

            logger.info(
                "Forecast validation completed successfully",
                validation_run_id=validation_run.id,
                forecasts_evaluated=validation_run.total_forecasts_evaluated,
                metrics_created=validation_run.metrics_records_created,
                overall_mape=validation_run.overall_mape,
                duration_seconds=validation_run.duration_seconds
            )
            # Extract poor accuracy products (MAPE > 30%)
            poor_accuracy_products = []
            if validation_run.metrics_by_product:
                for product_id, product_metrics in validation_run.metrics_by_product.items():
                    if product_metrics.get("mape", 0) > 30:
                        poor_accuracy_products.append({
                            "product_id": product_id,
                            "mape": product_metrics.get("mape"),
                            "mae": product_metrics.get("mae"),
                            "accuracy_percentage": product_metrics.get("accuracy_percentage")
                        })

            return {
                "validation_run_id": str(validation_run.id),
                "status": "completed",
                "forecasts_evaluated": validation_run.total_forecasts_evaluated,
                "forecasts_with_actuals": validation_run.forecasts_with_actuals,
                "forecasts_without_actuals": validation_run.forecasts_without_actuals,
                "metrics_created": validation_run.metrics_records_created,
                "overall_metrics": {
                    "mae": validation_run.overall_mae,
                    "mape": validation_run.overall_mape,
                    "rmse": validation_run.overall_rmse,
                    "r2_score": validation_run.overall_r2_score,
                    "accuracy_percentage": validation_run.overall_accuracy_percentage
                },
                "total_predicted_demand": validation_run.total_predicted_demand,
                "total_actual_demand": validation_run.total_actual_demand,
                "duration_seconds": validation_run.duration_seconds,
                "poor_accuracy_products": poor_accuracy_products,
                "metrics_by_product": validation_run.metrics_by_product,
                "metrics_by_location": validation_run.metrics_by_location
            }
        except Exception as e:
            logger.error(
                "Forecast validation failed",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )
            if validation_run:
                # Best-effort: record failure details on the run (the session
                # may itself be unusable if the failure was a database error)
                validation_run.status = "failed"
                validation_run.error_message = str(e)
                validation_run.error_details = {"error_type": type(e).__name__}
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()
            raise DatabaseError(f"Forecast validation failed: {str(e)}") from e

    async def validate_yesterday(
        self,
        tenant_id: uuid.UUID,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "orchestrator"
    ) -> Dict[str, Any]:
        """
        Convenience method to validate yesterday's forecasts

        Args:
            tenant_id: Tenant identifier
            orchestration_run_id: Optional link to orchestration run
            triggered_by: How this validation was triggered

        Returns:
            Dictionary with validation results
        """
        yesterday = datetime.now(timezone.utc) - timedelta(days=1)
        start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
        end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
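        # Resulting window, sketched for a call on 2025-11-18 (UTC):
        #   start_date = 2025-11-17 00:00:00+00:00
        #   end_date   = 2025-11-17 23:59:59.999999+00:00
        # i.e. exactly one UTC calendar day.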
        return await self.validate_date_range(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            orchestration_run_id=orchestration_run_id,
            triggered_by=triggered_by
        )

    async def _fetch_forecasts_with_sales(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime
    ) -> List[Dict[str, Any]]:
        """
        Fetch forecasts with their corresponding actual sales data

        Returns:
            List of dictionaries containing forecast and sales data
        """
        try:
            # Import here to avoid a circular dependency
            from app.services.sales_client import SalesClient

            # Fetch all forecasts in the date range
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == tenant_id,
                    func.cast(Forecast.forecast_date, Date) >= start_date.date(),
                    func.cast(Forecast.forecast_date, Date) <= end_date.date()
                )
            ).order_by(Forecast.forecast_date, Forecast.inventory_product_id)
            result = await self.db.execute(query)
            forecasts = result.scalars().all()
            if not forecasts:
                return []
            # Fetch actual sales data from the sales service
            sales_client = SalesClient()
            sales_data = await sales_client.get_sales_by_date_range(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date
            )

            # Build lookup dict: (product_id, date) -> sales record
            sales_lookup = {}
            for sale in sales_data:
                sale_date = sale['date']
                if isinstance(sale_date, str):
                    sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))
                key = (str(sale['inventory_product_id']), sale_date.date())
                # Sum quantities if there are multiple sales records for the same product/date
                if key in sales_lookup:
                    sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
                else:
                    sales_lookup[key] = sale
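            # Illustrative entry after the loop (hypothetical ID and figure):
            #   sales_lookup[("3f2a9c10-...", date(2025, 11, 17))] ==
            #       {"inventory_product_id": "3f2a9c10-...", "quantity_sold": 42.0, ...}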

            # Match forecasts with sales data
            forecasts_with_sales = []
            for forecast in forecasts:
                forecast_date = (
                    forecast.forecast_date.date()
                    if hasattr(forecast.forecast_date, 'date')
                    else forecast.forecast_date
                )
                key = (str(forecast.inventory_product_id), forecast_date)
                sales_record = sales_lookup.get(key)
                forecasts_with_sales.append({
                    "forecast_id": forecast.id,
                    "tenant_id": forecast.tenant_id,
                    "inventory_product_id": forecast.inventory_product_id,
                    "product_name": forecast.product_name,
                    "location": forecast.location,
                    "forecast_date": forecast.forecast_date,
                    "predicted_demand": forecast.predicted_demand,
                    "confidence_lower": forecast.confidence_lower,
                    "confidence_upper": forecast.confidence_upper,
                    "model_id": forecast.model_id,
                    "model_version": forecast.model_version,
                    "algorithm": forecast.algorithm,
                    "created_at": forecast.created_at,
                    "actual_sales": sales_record['quantity_sold'] if sales_record else None,
                    "has_actual_data": sales_record is not None
                })
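            # Example element (hypothetical values):
            #   {"product_name": "Baguette", "predicted_demand": 30.0,
            #    "actual_sales": 27.0, "has_actual_data": True, ...}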
            return forecasts_with_sales
        except Exception as e:
            logger.error(
                "Failed to fetch forecasts with sales data",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}") from e

    async def _calculate_and_store_metrics(
        self,
        forecasts_with_sales: List[Dict[str, Any]],
        validation_run_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Calculate accuracy metrics and store them in the database

        Returns:
            Summary of the metrics calculated
        """
        # Separate forecasts with and without actual data
        forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
        forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
        if not forecasts_with_actuals:
            return {
                "total_evaluated": len(forecasts_with_sales),
                "with_actuals": 0,
                "without_actuals": len(forecasts_without_actuals),
                "metrics_created": 0,
                "overall_mae": None,
                "overall_mape": None,
                "overall_rmse": None,
                "overall_r2_score": None,
                "overall_accuracy_percentage": None,
                "total_predicted": 0.0,
                "total_actual": 0.0,
                "metrics_by_product": {},
                "metrics_by_location": {}
            }
        # Calculate per-forecast metrics and prepare records for bulk insert
        performance_metrics = []
        for forecast in forecasts_with_actuals:
            predicted = forecast["predicted_demand"]
            actual = forecast["actual_sales"]
            error = abs(predicted - actual)

            # Percentage error (guard against division by zero)
            percentage_error = (error / actual * 100) if actual > 0 else 0.0

            # For a single observation, MAE and RMSE both equal the absolute
            # error (sqrt(error**2) == abs(error)), and MAPE equals the
            # percentage error
            mae = error
            mape = percentage_error
            rmse = error

            # Accuracy percentage: 100% minus MAPE, floored at 0
            accuracy_percentage = max(0.0, 100.0 - mape)

            # Create performance metric record
            metric = ModelPerformanceMetric(
                model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
                tenant_id=forecast["tenant_id"],
                inventory_product_id=forecast["inventory_product_id"],
                mae=mae,
                mape=mape,
                rmse=rmse,
                accuracy_score=accuracy_percentage / 100.0,  # Stored on a 0-1 scale
                evaluation_date=forecast["forecast_date"],
                evaluation_period_start=forecast["forecast_date"],
                evaluation_period_end=forecast["forecast_date"],
                sample_size=1
            )
            performance_metrics.append(metric)

        # Bulk insert all performance metrics
        if performance_metrics:
            self.db.add_all(performance_metrics)
            await self.db.flush()

        # Overall metrics plus per-product and per-location breakdowns
        overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
        metrics_by_product = self._calculate_metrics_by_dimension(
            forecasts_with_actuals, "inventory_product_id"
        )
        metrics_by_location = self._calculate_metrics_by_dimension(
            forecasts_with_actuals, "location"
        )

        return {
            "total_evaluated": len(forecasts_with_sales),
            "with_actuals": len(forecasts_with_actuals),
            "without_actuals": len(forecasts_without_actuals),
            "metrics_created": len(performance_metrics),
            "overall_mae": overall_metrics["mae"],
            "overall_mape": overall_metrics["mape"],
            "overall_rmse": overall_metrics["rmse"],
            "overall_r2_score": overall_metrics["r2_score"],
            "overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
            "total_predicted": overall_metrics["total_predicted"],
            "total_actual": overall_metrics["total_actual"],
            "metrics_by_product": metrics_by_product,
            "metrics_by_location": metrics_by_location
        }

    def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, Optional[float]]:
        """Calculate aggregated metrics across all forecasts"""
        if not forecasts:
            return {
                "mae": None, "mape": None, "rmse": None,
                "r2_score": None, "accuracy_percentage": None,
                "total_predicted": 0.0, "total_actual": 0.0
            }
        predicted_values = [f["predicted_demand"] for f in forecasts]
        actual_values = [f["actual_sales"] for f in forecasts]

        # MAE: Mean Absolute Error
        mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)

        # MAPE: Mean Absolute Percentage Error (guard against division by zero)
        mape_values = [
            abs(p - a) / a * 100 if a > 0 else 0.0
            for p, a in zip(predicted_values, actual_values)
        ]
        mape = sum(mape_values) / len(mape_values)

        # RMSE: Root Mean Square Error
        squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
        rmse = math.sqrt(sum(squared_errors) / len(squared_errors))

        # R² score (coefficient of determination)
        mean_actual = sum(actual_values) / len(actual_values)
        ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
        ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
        r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0

        # Accuracy percentage (100% - MAPE, floored at 0)
        accuracy_percentage = max(0.0, 100.0 - mape)

        return {
            "mae": round(mae, 2),
            "mape": round(mape, 2),
            "rmse": round(rmse, 2),
            "r2_score": round(r2_score, 4),
            "accuracy_percentage": round(accuracy_percentage, 2),
            "total_predicted": round(sum(predicted_values), 2),
            "total_actual": round(sum(actual_values), 2)
        }
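
    # Worked example for _calculate_overall_metrics (toy numbers, not real
    # data): predicted = [10, 20], actual = [8, 25]
    #   MAE  = (|10-8| + |20-25|) / 2 = 3.5
    #   MAPE = (2/8 + 5/25) / 2 * 100 = 22.5
    #   RMSE = sqrt((2**2 + 5**2) / 2) ≈ 3.81
    #   R²   = 1 - 29 / 144.5 ≈ 0.7993
    #   accuracy_percentage = 100 - 22.5 = 77.5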

    def _calculate_metrics_by_dimension(
        self,
        forecasts: List[Dict[str, Any]],
        dimension_key: str
    ) -> Dict[str, Dict[str, float]]:
        """Calculate metrics grouped by a dimension (product_id or location)"""
        # Group forecasts by dimension
        dimension_groups = {}
        for forecast in forecasts:
            key = str(forecast[dimension_key])
            dimension_groups.setdefault(key, []).append(forecast)

        # Calculate metrics for each group
        metrics_by_dimension = {}
        for key, group_forecasts in dimension_groups.items():
            metrics = self._calculate_overall_metrics(group_forecasts)
            metrics_by_dimension[key] = {
                "count": len(group_forecasts),
                "mae": metrics["mae"],
                "mape": metrics["mape"],
                "rmse": metrics["rmse"],
                "accuracy_percentage": metrics["accuracy_percentage"],
                "total_predicted": metrics["total_predicted"],
                "total_actual": metrics["total_actual"]
            }
        return metrics_by_dimension
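
    # Example output shape for dimension_key="location" (hypothetical values):
    #   {"downtown": {"count": 12, "mae": 2.1, "mape": 14.3, ...},
    #    "airport":  {"count": 7,  "mae": 4.0, "mape": 31.8, ...}}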

    async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
        """Get a validation run by ID"""
        try:
            query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
            result = await self.db.execute(query)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
            raise DatabaseError(f"Failed to get validation run: {str(e)}") from e

    async def get_validation_runs_by_tenant(
        self,
        tenant_id: uuid.UUID,
        limit: int = 50,
        skip: int = 0
    ) -> List[ValidationRun]:
        """Get validation runs for a tenant, newest first"""
        try:
            query = (
                select(ValidationRun)
                .where(ValidationRun.tenant_id == tenant_id)
                .order_by(ValidationRun.created_at.desc())
                .limit(limit)
                .offset(skip)
            )
            result = await self.db.execute(query)
            return result.scalars().all()
        except Exception as e:
            logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
            raise DatabaseError(f"Failed to get validation runs: {str(e)}") from e

    async def get_accuracy_trends(
        self,
        tenant_id: uuid.UUID,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get accuracy trends over time"""
        try:
            start_date = datetime.now(timezone.utc) - timedelta(days=days)
            query = (
                select(ValidationRun)
                .where(
                    and_(
                        ValidationRun.tenant_id == tenant_id,
                        ValidationRun.status == "completed",
                        ValidationRun.created_at >= start_date
                    )
                )
                .order_by(ValidationRun.created_at)
            )
            result = await self.db.execute(query)
            runs = result.scalars().all()
            if not runs:
                return {
                    "period_days": days,
                    "total_runs": 0,
                    "trends": []
                }
            trends = [
                {
                    "date": run.validation_start_date.isoformat(),
                    "mae": run.overall_mae,
                    "mape": run.overall_mape,
                    "rmse": run.overall_rmse,
                    "accuracy_percentage": run.overall_accuracy_percentage,
                    "forecasts_evaluated": run.total_forecasts_evaluated,
                    "forecasts_with_actuals": run.forecasts_with_actuals
                }
                for run in runs
            ]

            # Calculate averages across runs that produced metrics; compare
            # against None explicitly so a legitimate 0.0 average is kept
            valid_runs = [r for r in runs if r.overall_mape is not None]
            avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
            avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None
            return {
                "period_days": days,
                "total_runs": len(runs),
                "average_mape": round(avg_mape, 2) if avg_mape is not None else None,
                "average_accuracy": round(avg_accuracy, 2) if avg_accuracy is not None else None,
                "trends": trends
            }
        except Exception as e:
            logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
            raise DatabaseError(f"Failed to get accuracy trends: {str(e)}") from e