Improve backend
This commit is contained in:
@@ -0,0 +1,480 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/historical_validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Historical Validation Service
|
||||
|
||||
Handles validation backfill when historical sales data is uploaded late.
|
||||
Detects gaps in validation coverage and automatically triggers validation
|
||||
for periods where forecasts exist but haven't been validated yet.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date, or_
|
||||
from datetime import datetime, timedelta, timezone, date
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.sales_data_update import SalesDataUpdate
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HistoricalValidationService:
    """Service for backfilling historical validation when sales data arrives late"""

    def __init__(self, db_session: AsyncSession):
        # Shared async session; commits/rollbacks are issued by the public methods.
        self.db = db_session
        # Metric computation itself is delegated to the core validation service.
        self.validation_service = ValidationService(db_session)
|
||||
|
||||
async def detect_validation_gaps(
    self,
    tenant_id: uuid.UUID,
    lookback_days: int = 90
) -> List[Dict[str, Any]]:
    """
    Detect date ranges where forecasts exist but haven't been validated

    Args:
        tenant_id: Tenant identifier
        lookback_days: How far back to check (default 90 days)

    Returns:
        List of gap periods, each with "start_date", "end_date" (inclusive
        date objects) and "days_count"

    Raises:
        DatabaseError: If either query fails
    """
    try:
        end_date = datetime.now(timezone.utc)
        start_date = end_date - timedelta(days=lookback_days)

        logger.info(
            "Detecting validation gaps",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat()
        )

        # Distinct calendar dates that have at least one forecast
        forecast_query = select(
            func.cast(Forecast.forecast_date, Date).label('forecast_date')
        ).where(
            and_(
                Forecast.tenant_id == tenant_id,
                Forecast.forecast_date >= start_date,
                Forecast.forecast_date <= end_date
            )
        ).group_by(
            func.cast(Forecast.forecast_date, Date)
        ).order_by(
            func.cast(Forecast.forecast_date, Date)
        )

        forecast_result = await self.db.execute(forecast_query)
        forecast_dates = {row.forecast_date for row in forecast_result.fetchall()}

        if not forecast_dates:
            logger.info("No forecasts found in lookback period", tenant_id=tenant_id)
            return []

        # Distinct calendar dates covered by a completed validation run.
        # NOTE(review): only each run's *start* date is marked as validated;
        # a multi-day run would leave its later days flagged as gaps — confirm
        # that validation runs are always single-day before relying on this.
        validation_query = select(
            func.cast(ValidationRun.validation_start_date, Date).label('validated_date')
        ).where(
            and_(
                ValidationRun.tenant_id == tenant_id,
                ValidationRun.status == "completed",
                ValidationRun.validation_start_date >= start_date,
                ValidationRun.validation_end_date <= end_date
            )
        ).group_by(
            func.cast(ValidationRun.validation_start_date, Date)
        )

        validation_result = await self.db.execute(validation_query)
        validated_dates = {row.validated_date for row in validation_result.fetchall()}

        # Dates with forecasts but no completed validation
        gap_dates = sorted(forecast_dates - validated_dates)

        if not gap_dates:
            logger.info("No validation gaps found", tenant_id=tenant_id)
            return []

        gaps = self._group_consecutive_dates(gap_dates)

        logger.info(
            "Validation gaps detected",
            tenant_id=tenant_id,
            gaps_count=len(gaps),
            total_days=len(gap_dates)
        )

        return gaps

    except Exception as e:
        logger.error(
            "Failed to detect validation gaps",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to detect validation gaps: {str(e)}") from e

@staticmethod
def _group_consecutive_dates(gap_dates: List[date]) -> List[Dict[str, Any]]:
    """Collapse a sorted, non-empty list of dates into inclusive consecutive ranges."""
    gaps: List[Dict[str, Any]] = []
    range_start = gap_dates[0]
    range_end = gap_dates[0]

    for current in gap_dates[1:]:
        if (current - range_end).days == 1:
            # Consecutive date: extend the current range
            range_end = current
        else:
            # Break in the sequence: close the current range, start a new one
            gaps.append({
                "start_date": range_start,
                "end_date": range_end,
                "days_count": (range_end - range_start).days + 1
            })
            range_start = current
            range_end = current

    # Close the final open range
    gaps.append({
        "start_date": range_start,
        "end_date": range_end,
        "days_count": (range_end - range_start).days + 1
    })
    return gaps
|
||||
|
||||
async def backfill_validation(
    self,
    tenant_id: uuid.UUID,
    start_date: date,
    end_date: date,
    triggered_by: str = "manual",
    sales_data_update_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """
    Backfill validation for a historical date range

    Args:
        tenant_id: Tenant identifier
        start_date: Start date for backfill (inclusive)
        end_date: End date for backfill (inclusive)
        triggered_by: How this backfill was triggered
        sales_data_update_id: Optional link to sales data update record

    Returns:
        Backfill results with validation summary plus the backfilled range

    Raises:
        DatabaseError: If the validation run fails
    """
    try:
        logger.info(
            "Starting validation backfill",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat(),
            triggered_by=triggered_by
        )

        # Expand the date-only range to full-day UTC datetimes
        start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        end_datetime = datetime.combine(end_date, datetime.max.time()).replace(tzinfo=timezone.utc)

        # Run validation for the date range
        validation_result = await self.validation_service.validate_date_range(
            tenant_id=tenant_id,
            start_date=start_datetime,
            end_date=end_datetime,
            orchestration_run_id=None,
            triggered_by=triggered_by
        )

        # Link the validation outcome back to the sales data update, if any
        if sales_data_update_id:
            await self._update_sales_data_record(
                sales_data_update_id=sales_data_update_id,
                validation_run_id=uuid.UUID(validation_result["validation_run_id"]),
                status="completed" if validation_result["status"] == "completed" else "failed"
            )

        logger.info(
            "Validation backfill completed",
            tenant_id=tenant_id,
            validation_run_id=validation_result.get("validation_run_id"),
            forecasts_evaluated=validation_result.get("forecasts_evaluated")
        )

        return {
            **validation_result,
            "backfill_date_range": {
                "start": start_date.isoformat(),
                "end": end_date.isoformat()
            }
        }

    except Exception as e:
        logger.error(
            "Validation backfill failed",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat(),
            error=str(e)
        )

        # Record the failure on the linked sales data update (best-effort)
        if sales_data_update_id:
            await self._update_sales_data_record(
                sales_data_update_id=sales_data_update_id,
                validation_run_id=None,
                status="failed",
                error_message=str(e)
            )

        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Validation backfill failed: {str(e)}") from e
|
||||
|
||||
async def auto_backfill_gaps(
    self,
    tenant_id: uuid.UUID,
    lookback_days: int = 90,
    max_gaps_to_process: int = 10
) -> Dict[str, Any]:
    """
    Automatically detect and backfill validation gaps

    Args:
        tenant_id: Tenant identifier
        lookback_days: How far back to check
        max_gaps_to_process: Maximum number of gaps to process in one run

    Returns:
        Summary of backfill operations (counts plus per-gap results)

    Raises:
        DatabaseError: If gap detection itself fails; individual gap
            failures are captured per-gap in the results, not raised
    """
    try:
        logger.info(
            "Starting auto backfill",
            tenant_id=tenant_id,
            lookback_days=lookback_days
        )

        # Detect gaps
        gaps = await self.detect_validation_gaps(tenant_id, lookback_days)

        if not gaps:
            return {
                "gaps_found": 0,
                "gaps_processed": 0,
                "validations_completed": 0,
                "message": "No validation gaps found"
            }

        # Cap the amount of work done per invocation
        gaps_to_process = gaps[:max_gaps_to_process]

        results = []
        for gap in gaps_to_process:
            try:
                result = await self.backfill_validation(
                    tenant_id=tenant_id,
                    start_date=gap["start_date"],
                    end_date=gap["end_date"],
                    triggered_by="auto_backfill"
                )
                results.append({
                    "gap": gap,
                    "result": result,
                    "status": "success"
                })
            except Exception as e:
                # One failed gap must not abort the remaining gaps
                logger.error(
                    "Failed to backfill gap",
                    gap=gap,
                    error=str(e)
                )
                results.append({
                    "gap": gap,
                    "error": str(e),
                    "status": "failed"
                })

        successful = sum(1 for r in results if r["status"] == "success")

        logger.info(
            "Auto backfill completed",
            tenant_id=tenant_id,
            gaps_found=len(gaps),
            gaps_processed=len(results),
            successful=successful
        )

        return {
            "gaps_found": len(gaps),
            "gaps_processed": len(results),
            "validations_completed": successful,
            "validations_failed": len(results) - successful,
            "results": results
        }

    except Exception as e:
        logger.error(
            "Auto backfill failed",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Auto backfill failed: {str(e)}") from e
|
||||
|
||||
async def register_sales_data_update(
    self,
    tenant_id: uuid.UUID,
    start_date: date,
    end_date: date,
    records_affected: int,
    update_source: str = "import",
    import_job_id: Optional[str] = None,
    auto_trigger_validation: bool = True
) -> Dict[str, Any]:
    """
    Register a sales data update and optionally trigger validation

    Args:
        tenant_id: Tenant identifier
        start_date: Start date of updated data
        end_date: End date of updated data
        records_affected: Number of sales records affected
        update_source: Source of update (import, manual, pos_sync)
        import_job_id: Optional import job ID
        auto_trigger_validation: Whether to automatically trigger validation

    Returns:
        Dict with the update record; includes "validation_result" when
        validation was triggered successfully, "validation_error" when
        triggering failed (registration still succeeds in that case)

    Raises:
        DatabaseError: If the record itself cannot be persisted
    """
    try:
        # Create sales data update record
        update_record = SalesDataUpdate(
            tenant_id=tenant_id,
            update_date_start=start_date,
            update_date_end=end_date,
            records_affected=records_affected,
            update_source=update_source,
            import_job_id=import_job_id,
            requires_validation=auto_trigger_validation,
            validation_status="pending" if auto_trigger_validation else "not_required"
        )

        self.db.add(update_record)
        # Flush (not commit) so update_record.id is populated for linking below
        await self.db.flush()

        logger.info(
            "Registered sales data update",
            tenant_id=tenant_id,
            update_id=update_record.id,
            date_range=f"{start_date} to {end_date}",
            records_affected=records_affected
        )

        result = {
            "update_id": str(update_record.id),
            "update_record": update_record.to_dict(),
            "validation_triggered": False
        }

        # Trigger validation if requested
        if auto_trigger_validation:
            try:
                validation_result = await self.backfill_validation(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    triggered_by="sales_data_update",
                    sales_data_update_id=update_record.id
                )

                result["validation_triggered"] = True
                result["validation_result"] = validation_result

                logger.info(
                    "Validation triggered for sales data update",
                    update_id=update_record.id,
                    validation_run_id=validation_result.get("validation_run_id")
                )

            except Exception as e:
                # Registration still succeeds even if validation fails;
                # surface the failure on the record AND in the response.
                logger.error(
                    "Failed to trigger validation for sales data update",
                    update_id=update_record.id,
                    error=str(e)
                )
                update_record.validation_status = "failed"
                update_record.validation_error = str(e)[:500]
                result["validation_error"] = str(e)[:500]

        await self.db.commit()

        return result

    except Exception as e:
        logger.error(
            "Failed to register sales data update",
            tenant_id=tenant_id,
            error=str(e)
        )
        await self.db.rollback()
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to register sales data update: {str(e)}") from e
|
||||
|
||||
async def _update_sales_data_record(
    self,
    sales_data_update_id: uuid.UUID,
    validation_run_id: Optional[uuid.UUID],
    status: str,
    error_message: Optional[str] = None
):
    """Update sales data update record with validation results

    Best-effort: failures are logged and swallowed so a bookkeeping error
    never masks the validation outcome the caller is reporting.

    Args:
        sales_data_update_id: Record to update
        validation_run_id: Validation run to link (None on failure)
        status: New validation status (e.g. "completed" / "failed")
        error_message: Optional error text (truncated to 500 chars)
    """
    try:
        query = select(SalesDataUpdate).where(SalesDataUpdate.id == sales_data_update_id)
        result = await self.db.execute(query)
        update_record = result.scalar_one_or_none()

        if update_record:
            update_record.validation_status = status
            update_record.validation_run_id = validation_run_id
            update_record.validated_at = datetime.now(timezone.utc)
            if error_message:
                # Truncate defensively to fit the column
                update_record.validation_error = error_message[:500]

        await self.db.commit()

    except Exception as e:
        logger.error(
            "Failed to update sales data record",
            sales_data_update_id=sales_data_update_id,
            error=str(e)
        )
        # Roll back so the shared session isn't left in a failed state
        # for whatever the caller does next (error is still swallowed).
        await self.db.rollback()
|
||||
|
||||
async def get_pending_validations(
    self,
    tenant_id: uuid.UUID,
    limit: int = 50
) -> List[SalesDataUpdate]:
    """Get pending sales data updates that need validation

    Args:
        tenant_id: Tenant identifier
        limit: Maximum number of records to return

    Returns:
        Pending SalesDataUpdate records, oldest first

    Raises:
        DatabaseError: If the query fails
    """
    try:
        query = (
            select(SalesDataUpdate)
            .where(
                and_(
                    SalesDataUpdate.tenant_id == tenant_id,
                    SalesDataUpdate.validation_status == "pending",
                    # .is_(True) is the idiomatic SQLAlchemy boolean test
                    # (avoids the E712 `== True` comparison)
                    SalesDataUpdate.requires_validation.is_(True)
                )
            )
            .order_by(SalesDataUpdate.created_at)  # oldest first
            .limit(limit)
        )

        result = await self.db.execute(query)
        # Materialize to a list to match the declared return type
        return list(result.scalars().all())

    except Exception as e:
        logger.error(
            "Failed to get pending validations",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to get pending validations: {str(e)}") from e
|
||||
@@ -0,0 +1,435 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/performance_monitoring_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring Service
|
||||
|
||||
Monitors forecast accuracy over time and triggers actions when
|
||||
performance degrades below acceptable thresholds.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, desc
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.forecasts import Forecast
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceMonitoringService:
    """Service for monitoring forecast performance and triggering improvements"""

    # Configurable thresholds
    MAPE_WARNING_THRESHOLD: float = 20.0  # Warning if MAPE > 20%
    MAPE_CRITICAL_THRESHOLD: float = 30.0  # Critical if MAPE > 30%
    MAPE_TREND_THRESHOLD: float = 5.0  # Alert if MAPE increases by > 5% over period
    MIN_SAMPLES_FOR_ALERT: int = 5  # Minimum validations before alerting
    TREND_LOOKBACK_DAYS: int = 30  # Days to analyze for trends (not referenced in visible code — TODO confirm use)

    def __init__(self, db_session: AsyncSession):
        # Shared async session; this service only issues read queries.
        self.db = db_session
|
||||
|
||||
async def get_accuracy_summary(
    self,
    tenant_id: uuid.UUID,
    days: int = 30
) -> Dict[str, Any]:
    """
    Get accuracy summary for recent period

    Args:
        tenant_id: Tenant identifier
        days: Number of days to analyze

    Returns:
        Summary with overall metrics and a health status derived from the
        MAPE thresholds; {"status": "no_data", ...} when no runs exist

    Raises:
        DatabaseError: If the query fails
    """
    try:
        start_date = datetime.now(timezone.utc) - timedelta(days=days)

        # Completed validation runs that actually had sales data to compare
        query = (
            select(ValidationRun)
            .where(
                and_(
                    ValidationRun.tenant_id == tenant_id,
                    ValidationRun.status == "completed",
                    ValidationRun.started_at >= start_date,
                    ValidationRun.forecasts_with_actuals > 0  # Only runs with actual data
                )
            )
            .order_by(desc(ValidationRun.started_at))
        )

        result = await self.db.execute(query)
        runs = result.scalars().all()

        if not runs:
            return {
                "status": "no_data",
                "message": f"No validation runs found in last {days} days",
                "period_days": days
            }

        # Calculate summary statistics
        total_forecasts = sum(r.total_forecasts_evaluated for r in runs)
        total_with_actuals = sum(r.forecasts_with_actuals for r in runs)

        mape_values = [r.overall_mape for r in runs if r.overall_mape is not None]
        mae_values = [r.overall_mae for r in runs if r.overall_mae is not None]
        rmse_values = [r.overall_rmse for r in runs if r.overall_rmse is not None]

        avg_mape = sum(mape_values) / len(mape_values) if mape_values else None
        avg_mae = sum(mae_values) / len(mae_values) if mae_values else None
        avg_rmse = sum(rmse_values) / len(rmse_values) if rmse_values else None

        # Explicit `is not None` checks: a perfect 0.0 MAPE is falsy but
        # valid, and must be reported rather than collapsed to None.
        health_status = "healthy"
        if avg_mape is not None and avg_mape > self.MAPE_CRITICAL_THRESHOLD:
            health_status = "critical"
        elif avg_mape is not None and avg_mape > self.MAPE_WARNING_THRESHOLD:
            health_status = "warning"

        return {
            "status": "ok",
            "period_days": days,
            "validation_runs": len(runs),
            "total_forecasts_evaluated": total_forecasts,
            "total_forecasts_with_actuals": total_with_actuals,
            "coverage_percentage": round(
                (total_with_actuals / total_forecasts * 100) if total_forecasts > 0 else 0, 2
            ),
            "average_metrics": {
                "mape": round(avg_mape, 2) if avg_mape is not None else None,
                "mae": round(avg_mae, 2) if avg_mae is not None else None,
                "rmse": round(avg_rmse, 2) if avg_rmse is not None else None,
                "accuracy_percentage": round(100 - avg_mape, 2) if avg_mape is not None else None
            },
            "health_status": health_status,
            "thresholds": {
                "warning": self.MAPE_WARNING_THRESHOLD,
                "critical": self.MAPE_CRITICAL_THRESHOLD
            }
        }

    except Exception as e:
        logger.error(
            "Failed to get accuracy summary",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to get accuracy summary: {str(e)}") from e
|
||||
|
||||
async def detect_performance_degradation(
    self,
    tenant_id: uuid.UUID,
    lookback_days: int = 30
) -> Dict[str, Any]:
    """
    Detect if forecast performance is degrading over time

    Compares average MAPE of the older half of recent validation runs
    against the newer half.

    Args:
        tenant_id: Tenant identifier
        lookback_days: Days to analyze for trends

    Returns:
        Degradation analysis with recommendations, or
        {"status": "insufficient_data", ...} when too few runs/metrics exist

    Raises:
        DatabaseError: If the analysis fails
    """
    try:
        logger.info(
            "Detecting performance degradation",
            tenant_id=tenant_id,
            lookback_days=lookback_days
        )

        start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)

        # Completed runs with actual data, oldest first
        query = (
            select(ValidationRun)
            .where(
                and_(
                    ValidationRun.tenant_id == tenant_id,
                    ValidationRun.status == "completed",
                    ValidationRun.started_at >= start_date,
                    ValidationRun.forecasts_with_actuals > 0
                )
            )
            .order_by(ValidationRun.started_at)
        )

        result = await self.db.execute(query)
        runs = list(result.scalars().all())

        if len(runs) < self.MIN_SAMPLES_FOR_ALERT:
            return {
                "status": "insufficient_data",
                "message": f"Need at least {self.MIN_SAMPLES_FOR_ALERT} validation runs",
                "runs_found": len(runs)
            }

        # Compare the older half against the newer half
        midpoint = len(runs) // 2
        first_half = runs[:midpoint]
        second_half = runs[midpoint:]

        # `is not None` so a legitimate 0.0 MAPE isn't silently dropped
        first_mapes = [r.overall_mape for r in first_half if r.overall_mape is not None]
        second_mapes = [r.overall_mape for r in second_half if r.overall_mape is not None]

        # Guard: a half with no MAPE values would divide by zero
        if not first_mapes or not second_mapes:
            return {
                "status": "insufficient_data",
                "message": "Not enough MAPE data to compare periods",
                "runs_found": len(runs)
            }

        first_half_mape = sum(first_mapes) / len(first_mapes)
        second_half_mape = sum(second_mapes) / len(second_mapes)

        mape_change = second_half_mape - first_half_mape
        mape_change_percentage = (mape_change / first_half_mape * 100) if first_half_mape > 0 else 0

        # Degradation is significant only above the configured trend threshold
        is_degrading = mape_change > self.MAPE_TREND_THRESHOLD
        severity = "none"

        if is_degrading:
            if mape_change > self.MAPE_TREND_THRESHOLD * 2:
                severity = "high"
            else:
                # is_degrading already implies change > MAPE_TREND_THRESHOLD
                severity = "medium"

        # Products whose average MAPE exceeds the critical threshold
        poor_products = await self._identify_poor_performers(tenant_id, lookback_days)

        result = {
            "status": "analyzed",
            "period_days": lookback_days,
            "samples_analyzed": len(runs),
            "is_degrading": is_degrading,
            "severity": severity,
            "metrics": {
                "first_period_mape": round(first_half_mape, 2),
                "second_period_mape": round(second_half_mape, 2),
                "mape_change": round(mape_change, 2),
                "mape_change_percentage": round(mape_change_percentage, 2)
            },
            "poor_performers": poor_products,
            "recommendations": []
        }

        # Add recommendations
        if is_degrading:
            result["recommendations"].append({
                "action": "retrain_models",
                "priority": "high" if severity == "high" else "medium",
                "reason": f"MAPE increased by {abs(mape_change):.1f}% over {lookback_days} days"
            })

        if poor_products:
            result["recommendations"].append({
                "action": "retrain_poor_performers",
                "priority": "high",
                "reason": f"{len(poor_products)} products with MAPE > {self.MAPE_CRITICAL_THRESHOLD}%",
                "products": poor_products[:10]  # Top 10 worst
            })

        logger.info(
            "Performance degradation analysis complete",
            tenant_id=tenant_id,
            is_degrading=is_degrading,
            severity=severity,
            poor_performers=len(poor_products)
        )

        return result

    except Exception as e:
        logger.error(
            "Failed to detect performance degradation",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to detect degradation: {str(e)}") from e
|
||||
|
||||
async def _identify_poor_performers(
    self,
    tenant_id: uuid.UUID,
    lookback_days: int
) -> List[Dict[str, Any]]:
    """Identify products/locations with poor accuracy

    Best-effort helper: any failure is logged and an empty list returned
    so the caller's analysis can still complete.

    Args:
        tenant_id: Tenant identifier
        lookback_days: How far back to aggregate metrics

    Returns:
        Up to 20 products whose average MAPE exceeds the critical
        threshold, worst first
    """
    try:
        start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)

        # Recent performance metrics aggregated per product
        query = select(
            ModelPerformanceMetric.inventory_product_id,
            func.avg(ModelPerformanceMetric.mape).label('avg_mape'),
            func.avg(ModelPerformanceMetric.mae).label('avg_mae'),
            func.count(ModelPerformanceMetric.id).label('sample_count')
        ).where(
            and_(
                ModelPerformanceMetric.tenant_id == tenant_id,
                ModelPerformanceMetric.created_at >= start_date,
                ModelPerformanceMetric.mape.isnot(None)
            )
        ).group_by(
            ModelPerformanceMetric.inventory_product_id
        ).having(
            func.avg(ModelPerformanceMetric.mape) > self.MAPE_CRITICAL_THRESHOLD
        ).order_by(
            desc(func.avg(ModelPerformanceMetric.mape))
        ).limit(20)

        result = await self.db.execute(query)

        return [
            {
                "inventory_product_id": str(row.inventory_product_id),
                "avg_mape": round(row.avg_mape, 2),
                # Only mape is filtered non-NULL above; mae rows may all be
                # NULL for a product, so guard against round(None)
                "avg_mae": round(row.avg_mae, 2) if row.avg_mae is not None else None,
                "sample_count": row.sample_count,
                "requires_retraining": True
            }
            for row in result.fetchall()
        ]

    except Exception as e:
        logger.error("Failed to identify poor performers", error=str(e))
        return []
|
||||
|
||||
async def check_model_age(
    self,
    tenant_id: uuid.UUID,
    max_age_days: int = 30
) -> Dict[str, Any]:
    """
    Check if models are outdated and need retraining

    Args:
        tenant_id: Tenant identifier
        max_age_days: Maximum acceptable model age

    Returns:
        Analysis of model ages
    """
    try:
        # Survey models referenced by forecasts created in the last week
        recent_cutoff = datetime.now(timezone.utc) - timedelta(days=7)

        usage_query = select(
            Forecast.model_id,
            Forecast.model_version,
            Forecast.inventory_product_id,
            func.max(Forecast.created_at).label('last_used'),
            func.count(Forecast.id).label('forecast_count')
        ).where(
            and_(
                Forecast.tenant_id == tenant_id,
                Forecast.created_at >= recent_cutoff
            )
        ).group_by(
            Forecast.model_id,
            Forecast.model_version,
            Forecast.inventory_product_id
        )

        usage_rows = (await self.db.execute(usage_query)).fetchall()

        # Actual training timestamps live in the training service; until
        # that lookup is wired in, no model is flagged as outdated here.
        outdated_count = 0
        models_info = [
            {
                "model_id": row.model_id,
                "model_version": row.model_version,
                "inventory_product_id": str(row.inventory_product_id),
                "last_used": row.last_used.isoformat(),
                "forecast_count": row.forecast_count
            }
            for row in usage_rows
        ]

        return {
            "status": "analyzed",
            "models_in_use": len(models_info),
            "outdated_models": outdated_count,
            "max_age_days": max_age_days,
            "models": models_info[:20]  # Top 20
        }

    except Exception as e:
        logger.error("Failed to check model age", error=str(e))
        raise DatabaseError(f"Failed to check model age: {str(e)}")
|
||||
|
||||
async def generate_performance_report(
    self,
    tenant_id: uuid.UUID,
    days: int = 30
) -> Dict[str, Any]:
    """
    Generate comprehensive performance report

    Combines the accuracy summary, degradation analysis and model-age
    analysis into one report with compiled recommendations.

    Args:
        tenant_id: Tenant identifier
        days: Analysis period

    Returns:
        Complete performance report with recommendations

    Raises:
        DatabaseError: If any underlying analysis fails
    """
    try:
        logger.info(
            "Generating performance report",
            tenant_id=tenant_id,
            days=days
        )

        # Run all analyses over the same window (model age uses its default)
        summary = await self.get_accuracy_summary(tenant_id, days)
        degradation = await self.detect_performance_degradation(tenant_id, days)
        model_age = await self.check_model_age(tenant_id)

        # Compile recommendations
        all_recommendations = []

        # A "critical" health_status implies average_metrics is present,
        # so the nested lookup below is safe on this path
        if summary.get("health_status") == "critical":
            all_recommendations.append({
                "priority": "critical",
                "action": "immediate_review",
                "reason": f"Overall MAPE is {summary['average_metrics']['mape']}%",
                "details": "Forecast accuracy is critically low"
            })

        all_recommendations.extend(degradation.get("recommendations", []))

        report = {
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "tenant_id": str(tenant_id),
            "analysis_period_days": days,
            "summary": summary,
            "degradation_analysis": degradation,
            "model_age_analysis": model_age,
            "recommendations": all_recommendations,
            "requires_action": len(all_recommendations) > 0
        }

        logger.info(
            "Performance report generated",
            tenant_id=tenant_id,
            health_status=summary.get("health_status"),
            recommendations_count=len(all_recommendations)
        )

        return report

    except Exception as e:
        logger.error(
            "Failed to generate performance report",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause isn't lost
        raise DatabaseError(f"Failed to generate report: {str(e)}") from e
|
||||
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/retraining_trigger_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining Trigger Service
|
||||
|
||||
Automatically triggers model retraining based on performance metrics,
|
||||
accuracy degradation, or data availability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.clients.training_client import TrainingServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RetrainingTriggerService:
    """Service for triggering automatic model retraining"""

    def __init__(self, db_session: AsyncSession):
        # Shared async session, passed through to the monitoring service.
        self.db = db_session
        # Performance analysis is delegated to the monitoring service.
        self.performance_service = PerformanceMonitoringService(db_session)

        # Initialize training client
        # NOTE(review): BaseServiceSettings() presumably loads configuration
        # (e.g. from environment) at construction time — confirm this is
        # safe/cheap to do per-instance rather than once at module load.
        config = BaseServiceSettings()
        self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")
|
||||
|
||||
async def evaluate_and_trigger_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
auto_trigger: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate performance and trigger retraining if needed
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
auto_trigger: Whether to automatically trigger retraining
|
||||
|
||||
Returns:
|
||||
Evaluation results and retraining actions taken
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Evaluating retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=auto_trigger
|
||||
)
|
||||
|
||||
# Generate performance report
|
||||
report = await self.performance_service.generate_performance_report(
|
||||
tenant_id=tenant_id,
|
||||
days=30
|
||||
)
|
||||
|
||||
if not report.get("requires_action"):
|
||||
logger.info(
|
||||
"No retraining required",
|
||||
tenant_id=tenant_id,
|
||||
health_status=report["summary"].get("health_status")
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"health_status": report["summary"].get("health_status"),
|
||||
"report": report
|
||||
}
|
||||
|
||||
# Extract products that need retraining
|
||||
products_to_retrain = []
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
for rec in recommendations:
|
||||
if rec.get("action") == "retrain_poor_performers":
|
||||
products_to_retrain.extend(rec.get("products", []))
|
||||
|
||||
if not products_to_retrain and auto_trigger:
|
||||
# If degradation detected but no specific products, consider retraining all
|
||||
degradation = report.get("degradation_analysis", {})
|
||||
if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
|
||||
logger.info(
|
||||
"General degradation detected, considering full retraining",
|
||||
tenant_id=tenant_id,
|
||||
severity=degradation.get("severity")
|
||||
)
|
||||
|
||||
retraining_results = []
|
||||
|
||||
if auto_trigger and products_to_retrain:
|
||||
# Trigger retraining for poor performers
|
||||
for product in products_to_retrain:
|
||||
try:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=uuid.UUID(product["inventory_product_id"]),
|
||||
reason=f"MAPE {product['avg_mape']}% exceeds threshold",
|
||||
priority="high"
|
||||
)
|
||||
retraining_results.append(result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger retraining for product",
|
||||
product_id=product["inventory_product_id"],
|
||||
error=str(e)
|
||||
)
|
||||
retraining_results.append({
|
||||
"product_id": product["inventory_product_id"],
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
logger.info(
|
||||
"Retraining evaluation complete",
|
||||
tenant_id=tenant_id,
|
||||
products_evaluated=len(products_to_retrain),
|
||||
retraining_triggered=len(retraining_results)
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "evaluated",
|
||||
"tenant_id": str(tenant_id),
|
||||
"requires_action": report.get("requires_action"),
|
||||
"products_needing_retraining": len(products_to_retrain),
|
||||
"retraining_triggered": len(retraining_results),
|
||||
"auto_trigger_enabled": auto_trigger,
|
||||
"retraining_results": retraining_results,
|
||||
"performance_report": report
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to evaluate and trigger retraining",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")
|
||||
|
||||
async def _trigger_product_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
inventory_product_id: uuid.UUID,
|
||||
reason: str,
|
||||
priority: str = "normal"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for a specific product
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
inventory_product_id: Product to retrain
|
||||
reason: Reason for retraining
|
||||
priority: Priority level (low, normal, high)
|
||||
|
||||
Returns:
|
||||
Retraining trigger result
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
# Call training service to trigger retraining
|
||||
result = await self.training_client.trigger_retrain(
|
||||
tenant_id=str(tenant_id),
|
||||
inventory_product_id=str(inventory_product_id),
|
||||
reason=reason,
|
||||
priority=priority
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.info(
|
||||
"Retraining triggered successfully",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
training_job_id=result.get("training_job_id")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "triggered",
|
||||
"product_id": str(inventory_product_id),
|
||||
"training_job_id": result.get("training_job_id"),
|
||||
"reason": reason,
|
||||
"priority": priority,
|
||||
"triggered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
"Retraining trigger returned no result",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id
|
||||
)
|
||||
return {
|
||||
"status": "no_response",
|
||||
"product_id": str(inventory_product_id),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to trigger product retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_id=inventory_product_id,
|
||||
error=str(e)
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"product_id": str(inventory_product_id),
|
||||
"error": str(e),
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
async def trigger_bulk_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
product_ids: List[uuid.UUID],
|
||||
reason: str = "Bulk retraining requested"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger retraining for multiple products
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_ids: List of products to retrain
|
||||
reason: Reason for bulk retraining
|
||||
|
||||
Returns:
|
||||
Bulk retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Triggering bulk retraining",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(product_ids)
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for product_id in product_ids:
|
||||
result = await self._trigger_product_retraining(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=product_id,
|
||||
reason=reason,
|
||||
priority="normal"
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "triggered")
|
||||
|
||||
logger.info(
|
||||
"Bulk retraining completed",
|
||||
tenant_id=tenant_id,
|
||||
total=len(product_ids),
|
||||
successful=successful,
|
||||
failed=len(product_ids) - successful
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"total_products": len(product_ids),
|
||||
"successful": successful,
|
||||
"failed": len(product_ids) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Bulk retraining failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Bulk retraining failed: {str(e)}")
|
||||
|
||||
async def check_and_trigger_scheduled_retraining(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
max_model_age_days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check model ages and trigger retraining for outdated models
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
max_model_age_days: Maximum acceptable model age
|
||||
|
||||
Returns:
|
||||
Scheduled retraining results
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Checking for scheduled retraining needs",
|
||||
tenant_id=tenant_id,
|
||||
max_model_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
# Get model age analysis
|
||||
model_age_analysis = await self.performance_service.check_model_age(
|
||||
tenant_id=tenant_id,
|
||||
max_age_days=max_model_age_days
|
||||
)
|
||||
|
||||
outdated_count = model_age_analysis.get("outdated_models", 0)
|
||||
|
||||
if outdated_count == 0:
|
||||
logger.info(
|
||||
"No outdated models found",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {
|
||||
"status": "no_action_needed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": 0
|
||||
}
|
||||
|
||||
# TODO: Trigger retraining for outdated models
|
||||
# Would need to get list of outdated products from training service
|
||||
|
||||
return {
|
||||
"status": "analyzed",
|
||||
"tenant_id": str(tenant_id),
|
||||
"outdated_models": outdated_count,
|
||||
"message": "Scheduled retraining analysis complete"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Scheduled retraining check failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")
|
||||
|
||||
async def get_retraining_recommendations(
|
||||
self,
|
||||
tenant_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get retraining recommendations without triggering
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
|
||||
Returns:
|
||||
Recommendations for manual review
|
||||
"""
|
||||
try:
|
||||
# Evaluate without auto-triggering
|
||||
result = await self.evaluate_and_trigger_retraining(
|
||||
tenant_id=tenant_id,
|
||||
auto_trigger=False
|
||||
)
|
||||
|
||||
# Extract just the recommendations
|
||||
report = result.get("performance_report", {})
|
||||
recommendations = report.get("recommendations", [])
|
||||
|
||||
return {
|
||||
"tenant_id": str(tenant_id),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"requires_action": result.get("requires_action", False),
|
||||
"recommendations": recommendations,
|
||||
"summary": report.get("summary", {}),
|
||||
"degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get retraining recommendations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get recommendations: {str(e)}")
|
||||
97
services/forecasting/app/services/sales_client.py
Normal file
97
services/forecasting/app/services/sales_client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/sales_client.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Client for Forecasting Service
|
||||
Wrapper around shared sales client with forecasting-specific methods
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SalesClient:
    """Client for fetching sales data from sales service.

    Thin wrapper around the shared SalesServiceClient with
    forecasting-friendly defaults and best-effort error handling.
    """

    def __init__(self):
        """Initialize sales client"""
        # Build the shared HTTP client from the service's base settings.
        settings = BaseServiceSettings()
        self.client = SalesServiceClient(settings, calling_service_name="forecasting")

    async def get_sales_by_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        product_id: uuid.UUID = None
    ) -> List[Dict[str, Any]]:
        """
        Get sales data for a date range

        Args:
            tenant_id: Tenant identifier
            start_date: Start of date range
            end_date: End of date range
            product_id: Optional product filter

        Returns:
            List of sales records; an empty list on error (best-effort,
            so downstream validation can still proceed).
        """
        try:
            # Serialize identifiers/dates into the string form the
            # HTTP client expects; None values pass through unchanged.
            tenant_str = str(tenant_id)
            start_str = start_date.isoformat() if start_date else None
            end_str = end_date.isoformat() if end_date else None
            product_str = str(product_id) if product_id else None

            # Paginated fetch of the full result set.
            records = await self.client.get_all_sales_data(
                tenant_id=tenant_str,
                start_date=start_str,
                end_date=end_str,
                product_id=product_str,
                aggregation="none",  # Get raw data without aggregation
                page_size=1000,
                max_pages=100
            )

            if records:
                logger.info(
                    "Retrieved sales data",
                    tenant_id=tenant_id,
                    records_count=len(records),
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                return records

            logger.info(
                "No sales data found for date range",
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )
            return []

        except Exception as e:
            logger.error(
                "Failed to fetch sales data",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )
            # Return empty list instead of raising to allow validation to continue
            return []

    async def close(self):
        """Close the client connection"""
        # Guarded: some client implementations may not expose close().
        if hasattr(self.client, 'close'):
            await self.client.close()
|
||||
586
services/forecasting/app/services/validation_service.py
Normal file
586
services/forecasting/app/services/validation_service.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast Validation Service
|
||||
|
||||
Compares historical forecasts with actual sales data to:
|
||||
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
|
||||
2. Store performance metrics in the database
|
||||
3. Track validation runs for audit purposes
|
||||
4. Enable continuous model improvement
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.validation_run import ValidationRun
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ValidationService:
|
||||
"""Service for validating forecasts against actual sales data"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db = db_session
|
||||
|
||||
    async def validate_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "manual"
    ) -> Dict[str, Any]:
        """
        Validate forecasts against actual sales for a date range

        Args:
            tenant_id: Tenant identifier
            start_date: Start of validation period
            end_date: End of validation period
            orchestration_run_id: Optional link to orchestration run
            triggered_by: How this validation was triggered (manual, orchestrator, scheduled)

        Returns:
            Dictionary with validation results and metrics

        Raises:
            DatabaseError: On any failure; the ValidationRun row is marked
                "failed" first when it was created.
        """
        validation_run = None

        try:
            # Create validation run record (audit trail for this validation)
            validation_run = ValidationRun(
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                validation_start_date=start_date,
                validation_end_date=end_date,
                status="running",
                triggered_by=triggered_by,
                execution_mode="batch"
            )
            self.db.add(validation_run)
            # Flush so validation_run.id is populated before use.
            await self.db.flush()

            logger.info(
                "Starting forecast validation",
                validation_run_id=validation_run.id,
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )

            # Fetch forecasts with matching sales data
            forecasts_with_sales = await self._fetch_forecasts_with_sales(
                tenant_id, start_date, end_date
            )

            if not forecasts_with_sales:
                logger.warning(
                    "No forecasts with matching sales data found",
                    tenant_id=tenant_id,
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                validation_run.status = "completed"
                validation_run.completed_at = datetime.now(timezone.utc)
                # NOTE(review): started_at is never assigned here — assumes the
                # ValidationRun model sets it via a column default; confirm.
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

                return {
                    "validation_run_id": str(validation_run.id),
                    "status": "completed",
                    "message": "No forecasts with matching sales data found",
                    "forecasts_evaluated": 0,
                    "metrics_created": 0
                }

            # Calculate metrics and create performance records
            metrics_results = await self._calculate_and_store_metrics(
                forecasts_with_sales, validation_run.id
            )

            # Update validation run with results
            validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
            validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
            validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
            validation_run.overall_mae = metrics_results["overall_mae"]
            validation_run.overall_mape = metrics_results["overall_mape"]
            validation_run.overall_rmse = metrics_results["overall_rmse"]
            validation_run.overall_r2_score = metrics_results["overall_r2_score"]
            validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
            validation_run.total_predicted_demand = metrics_results["total_predicted"]
            validation_run.total_actual_demand = metrics_results["total_actual"]
            validation_run.metrics_by_product = metrics_results["metrics_by_product"]
            validation_run.metrics_by_location = metrics_results["metrics_by_location"]
            validation_run.metrics_records_created = metrics_results["metrics_created"]
            validation_run.status = "completed"
            validation_run.completed_at = datetime.now(timezone.utc)
            validation_run.duration_seconds = (
                validation_run.completed_at - validation_run.started_at
            ).total_seconds()

            await self.db.commit()

            logger.info(
                "Forecast validation completed successfully",
                validation_run_id=validation_run.id,
                forecasts_evaluated=validation_run.total_forecasts_evaluated,
                metrics_created=validation_run.metrics_records_created,
                overall_mape=validation_run.overall_mape,
                duration_seconds=validation_run.duration_seconds
            )

            # Extract poor accuracy products (MAPE > 30%)
            poor_accuracy_products = []
            if validation_run.metrics_by_product:
                for product_id, product_metrics in validation_run.metrics_by_product.items():
                    if product_metrics.get("mape", 0) > 30:
                        poor_accuracy_products.append({
                            "product_id": product_id,
                            "mape": product_metrics.get("mape"),
                            "mae": product_metrics.get("mae"),
                            "accuracy_percentage": product_metrics.get("accuracy_percentage")
                        })

            return {
                "validation_run_id": str(validation_run.id),
                "status": "completed",
                "forecasts_evaluated": validation_run.total_forecasts_evaluated,
                "forecasts_with_actuals": validation_run.forecasts_with_actuals,
                "forecasts_without_actuals": validation_run.forecasts_without_actuals,
                "metrics_created": validation_run.metrics_records_created,
                "overall_metrics": {
                    "mae": validation_run.overall_mae,
                    "mape": validation_run.overall_mape,
                    "rmse": validation_run.overall_rmse,
                    "r2_score": validation_run.overall_r2_score,
                    "accuracy_percentage": validation_run.overall_accuracy_percentage
                },
                "total_predicted_demand": validation_run.total_predicted_demand,
                "total_actual_demand": validation_run.total_actual_demand,
                "duration_seconds": validation_run.duration_seconds,
                "poor_accuracy_products": poor_accuracy_products,
                "metrics_by_product": validation_run.metrics_by_product,
                "metrics_by_location": validation_run.metrics_by_location
            }

        except Exception as e:
            logger.error(
                "Forecast validation failed",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )

            if validation_run:
                # Best-effort: mark the run as failed for the audit trail.
                # NOTE(review): committing here without a prior rollback may
                # itself fail if the session is in an aborted state — confirm.
                validation_run.status = "failed"
                validation_run.error_message = str(e)
                validation_run.error_details = {"error_type": type(e).__name__}
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

            raise DatabaseError(f"Forecast validation failed: {str(e)}")
|
||||
|
||||
async def validate_yesterday(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None,
|
||||
triggered_by: str = "orchestrator"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convenience method to validate yesterday's forecasts
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
orchestration_run_id: Optional link to orchestration run
|
||||
triggered_by: How this validation was triggered
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
|
||||
return await self.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by=triggered_by
|
||||
)
|
||||
|
||||
async def _fetch_forecasts_with_sales(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
start_date: datetime,
|
||||
end_date: datetime
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch forecasts with their corresponding actual sales data
|
||||
|
||||
Returns list of dictionaries containing forecast and sales data
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular dependency
|
||||
from app.services.sales_client import SalesClient
|
||||
|
||||
# Query to get all forecasts in the date range
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == tenant_id,
|
||||
func.cast(Forecast.forecast_date, Date) >= start_date.date(),
|
||||
func.cast(Forecast.forecast_date, Date) <= end_date.date()
|
||||
)
|
||||
).order_by(Forecast.forecast_date, Forecast.inventory_product_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
forecasts = result.scalars().all()
|
||||
|
||||
if not forecasts:
|
||||
return []
|
||||
|
||||
# Fetch actual sales data from sales service
|
||||
sales_client = SalesClient()
|
||||
sales_data = await sales_client.get_sales_by_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# Create lookup dict: (product_id, date) -> sales quantity
|
||||
sales_lookup = {}
|
||||
for sale in sales_data:
|
||||
sale_date = sale['date']
|
||||
if isinstance(sale_date, str):
|
||||
sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))
|
||||
|
||||
key = (str(sale['inventory_product_id']), sale_date.date())
|
||||
|
||||
# Sum quantities if multiple sales records for same product/date
|
||||
if key in sales_lookup:
|
||||
sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
|
||||
else:
|
||||
sales_lookup[key] = sale
|
||||
|
||||
# Match forecasts with sales data
|
||||
forecasts_with_sales = []
|
||||
for forecast in forecasts:
|
||||
forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
|
||||
key = (str(forecast.inventory_product_id), forecast_date)
|
||||
|
||||
sales_record = sales_lookup.get(key)
|
||||
|
||||
forecasts_with_sales.append({
|
||||
"forecast_id": forecast.id,
|
||||
"tenant_id": forecast.tenant_id,
|
||||
"inventory_product_id": forecast.inventory_product_id,
|
||||
"product_name": forecast.product_name,
|
||||
"location": forecast.location,
|
||||
"forecast_date": forecast.forecast_date,
|
||||
"predicted_demand": forecast.predicted_demand,
|
||||
"confidence_lower": forecast.confidence_lower,
|
||||
"confidence_upper": forecast.confidence_upper,
|
||||
"model_id": forecast.model_id,
|
||||
"model_version": forecast.model_version,
|
||||
"algorithm": forecast.algorithm,
|
||||
"created_at": forecast.created_at,
|
||||
"actual_sales": sales_record['quantity_sold'] if sales_record else None,
|
||||
"has_actual_data": sales_record is not None
|
||||
})
|
||||
|
||||
return forecasts_with_sales
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to fetch forecasts with sales data",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}")
|
||||
|
||||
async def _calculate_and_store_metrics(
|
||||
self,
|
||||
forecasts_with_sales: List[Dict[str, Any]],
|
||||
validation_run_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate accuracy metrics and store them in the database
|
||||
|
||||
Returns summary of metrics calculated
|
||||
"""
|
||||
# Separate forecasts with and without actual data
|
||||
forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
|
||||
forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
|
||||
|
||||
if not forecasts_with_actuals:
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": 0,
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": 0,
|
||||
"overall_mae": None,
|
||||
"overall_mape": None,
|
||||
"overall_rmse": None,
|
||||
"overall_r2_score": None,
|
||||
"overall_accuracy_percentage": None,
|
||||
"total_predicted": 0.0,
|
||||
"total_actual": 0.0,
|
||||
"metrics_by_product": {},
|
||||
"metrics_by_location": {}
|
||||
}
|
||||
|
||||
# Calculate individual metrics and prepare for bulk insert
|
||||
performance_metrics = []
|
||||
errors = []
|
||||
|
||||
for forecast in forecasts_with_actuals:
|
||||
predicted = forecast["predicted_demand"]
|
||||
actual = forecast["actual_sales"]
|
||||
|
||||
# Calculate error
|
||||
error = abs(predicted - actual)
|
||||
errors.append(error)
|
||||
|
||||
# Calculate percentage error (avoid division by zero)
|
||||
percentage_error = (error / actual * 100) if actual > 0 else 0.0
|
||||
|
||||
# Calculate individual metrics
|
||||
mae = error
|
||||
mape = percentage_error
|
||||
rmse = error # Will be squared and averaged later
|
||||
|
||||
# Calculate accuracy percentage (100% - MAPE, capped at 0)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
# Create performance metric record
|
||||
metric = ModelPerformanceMetric(
|
||||
model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
|
||||
tenant_id=forecast["tenant_id"],
|
||||
inventory_product_id=forecast["inventory_product_id"],
|
||||
mae=mae,
|
||||
mape=mape,
|
||||
rmse=rmse,
|
||||
accuracy_score=accuracy_percentage / 100.0, # Store as 0-1 scale
|
||||
evaluation_date=forecast["forecast_date"],
|
||||
evaluation_period_start=forecast["forecast_date"],
|
||||
evaluation_period_end=forecast["forecast_date"],
|
||||
sample_size=1
|
||||
)
|
||||
performance_metrics.append(metric)
|
||||
|
||||
# Bulk insert all performance metrics
|
||||
if performance_metrics:
|
||||
self.db.add_all(performance_metrics)
|
||||
await self.db.flush()
|
||||
|
||||
# Calculate overall metrics
|
||||
overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
|
||||
|
||||
# Calculate metrics by product
|
||||
metrics_by_product = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "inventory_product_id"
|
||||
)
|
||||
|
||||
# Calculate metrics by location
|
||||
metrics_by_location = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "location"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": len(forecasts_with_actuals),
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": len(performance_metrics),
|
||||
"overall_mae": overall_metrics["mae"],
|
||||
"overall_mape": overall_metrics["mape"],
|
||||
"overall_rmse": overall_metrics["rmse"],
|
||||
"overall_r2_score": overall_metrics["r2_score"],
|
||||
"overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
|
||||
"total_predicted": overall_metrics["total_predicted"],
|
||||
"total_actual": overall_metrics["total_actual"],
|
||||
"metrics_by_product": metrics_by_product,
|
||||
"metrics_by_location": metrics_by_location
|
||||
}
|
||||
|
||||
def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
"""Calculate aggregated metrics across all forecasts"""
|
||||
if not forecasts:
|
||||
return {
|
||||
"mae": None, "mape": None, "rmse": None,
|
||||
"r2_score": None, "accuracy_percentage": None,
|
||||
"total_predicted": 0.0, "total_actual": 0.0
|
||||
}
|
||||
|
||||
predicted_values = [f["predicted_demand"] for f in forecasts]
|
||||
actual_values = [f["actual_sales"] for f in forecasts]
|
||||
|
||||
# MAE: Mean Absolute Error
|
||||
mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)
|
||||
|
||||
# MAPE: Mean Absolute Percentage Error (handle division by zero)
|
||||
mape_values = [
|
||||
abs(p - a) / a * 100 if a > 0 else 0.0
|
||||
for p, a in zip(predicted_values, actual_values)
|
||||
]
|
||||
mape = sum(mape_values) / len(mape_values)
|
||||
|
||||
# RMSE: Root Mean Square Error
|
||||
squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
|
||||
rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
|
||||
|
||||
# R² Score (coefficient of determination)
|
||||
mean_actual = sum(actual_values) / len(actual_values)
|
||||
ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
|
||||
ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
|
||||
r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0
|
||||
|
||||
# Accuracy percentage (100% - MAPE)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
return {
|
||||
"mae": round(mae, 2),
|
||||
"mape": round(mape, 2),
|
||||
"rmse": round(rmse, 2),
|
||||
"r2_score": round(r2_score, 4),
|
||||
"accuracy_percentage": round(accuracy_percentage, 2),
|
||||
"total_predicted": round(sum(predicted_values), 2),
|
||||
"total_actual": round(sum(actual_values), 2)
|
||||
}
|
||||
|
||||
def _calculate_metrics_by_dimension(
|
||||
self,
|
||||
forecasts: List[Dict[str, Any]],
|
||||
dimension_key: str
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""Calculate metrics grouped by a dimension (product_id or location)"""
|
||||
dimension_groups = {}
|
||||
|
||||
# Group forecasts by dimension
|
||||
for forecast in forecasts:
|
||||
key = str(forecast[dimension_key])
|
||||
if key not in dimension_groups:
|
||||
dimension_groups[key] = []
|
||||
dimension_groups[key].append(forecast)
|
||||
|
||||
# Calculate metrics for each group
|
||||
metrics_by_dimension = {}
|
||||
for key, group_forecasts in dimension_groups.items():
|
||||
metrics = self._calculate_overall_metrics(group_forecasts)
|
||||
metrics_by_dimension[key] = {
|
||||
"count": len(group_forecasts),
|
||||
"mae": metrics["mae"],
|
||||
"mape": metrics["mape"],
|
||||
"rmse": metrics["rmse"],
|
||||
"accuracy_percentage": metrics["accuracy_percentage"],
|
||||
"total_predicted": metrics["total_predicted"],
|
||||
"total_actual": metrics["total_actual"]
|
||||
}
|
||||
|
||||
return metrics_by_dimension
|
||||
|
||||
async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
    """
    Fetch a single validation run by its primary key.

    Args:
        validation_run_id: Validation run identifier.

    Returns:
        The matching ValidationRun, or None if no row exists.

    Raises:
        DatabaseError: If the underlying query fails.
    """
    try:
        query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
        result = await self.db.execute(query)
        return result.scalar_one_or_none()
    except Exception as e:
        logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
        # Chain the original exception so the full traceback is preserved.
        raise DatabaseError(f"Failed to get validation run: {str(e)}") from e
||||
async def get_validation_runs_by_tenant(
    self,
    tenant_id: uuid.UUID,
    limit: int = 50,
    skip: int = 0
) -> List[ValidationRun]:
    """
    List validation runs for a tenant, newest first.

    Args:
        tenant_id: Tenant identifier.
        limit: Maximum number of rows to return.
        skip: Number of rows to skip (pagination offset).

    Returns:
        List of ValidationRun rows ordered by created_at descending.

    Raises:
        DatabaseError: If the underlying query fails.
    """
    try:
        query = (
            select(ValidationRun)
            .where(ValidationRun.tenant_id == tenant_id)
            .order_by(ValidationRun.created_at.desc())
            .limit(limit)
            .offset(skip)
        )
        result = await self.db.execute(query)
        # Materialize to a list so the annotated return type holds on
        # SQLAlchemy 2.0, where scalars().all() yields a Sequence.
        return list(result.scalars().all())
    except Exception as e:
        logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
        # Chain the original exception so the full traceback is preserved.
        raise DatabaseError(f"Failed to get validation runs: {str(e)}") from e
||||
async def get_accuracy_trends(
    self,
    tenant_id: uuid.UUID,
    days: int = 30
) -> Dict[str, Any]:
    """
    Summarize forecast-accuracy trends for a tenant over a recent window.

    Args:
        tenant_id: Tenant identifier.
        days: Size of the lookback window in days (default 30).

    Returns:
        Dict with period_days, total_runs, average_mape,
        average_accuracy and a per-run "trends" list ordered by
        creation time. Averages are None when no run produced a MAPE.

    Raises:
        DatabaseError: If the underlying query fails.
    """
    try:
        start_date = datetime.now(timezone.utc) - timedelta(days=days)

        # Only completed runs inside the window contribute to trends.
        query = (
            select(ValidationRun)
            .where(
                and_(
                    ValidationRun.tenant_id == tenant_id,
                    ValidationRun.status == "completed",
                    ValidationRun.created_at >= start_date
                )
            )
            .order_by(ValidationRun.created_at)
        )

        result = await self.db.execute(query)
        runs = result.scalars().all()

        if not runs:
            return {
                "period_days": days,
                "total_runs": 0,
                "trends": []
            }

        trends = [
            {
                "date": run.validation_start_date.isoformat(),
                "mae": run.overall_mae,
                "mape": run.overall_mape,
                "rmse": run.overall_rmse,
                "accuracy_percentage": run.overall_accuracy_percentage,
                "forecasts_evaluated": run.total_forecasts_evaluated,
                "forecasts_with_actuals": run.forecasts_with_actuals
            }
            for run in runs
        ]

        # Average only over runs that actually produced a MAPE.
        valid_runs = [r for r in runs if r.overall_mape is not None]
        avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
        # NOTE(review): assumes overall_accuracy_percentage is non-None
        # whenever overall_mape is -- confirm against the model.
        avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None

        return {
            "period_days": days,
            "total_runs": len(runs),
            # Explicit None checks: a legitimate 0.0 average must be
            # reported as 0.0, not collapsed to None by a truthiness test.
            "average_mape": round(avg_mape, 2) if avg_mape is not None else None,
            "average_accuracy": round(avg_accuracy, 2) if avg_accuracy is not None else None,
            "trends": trends
        }

    except Exception as e:
        logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
        # Chain the original exception so the full traceback is preserved.
        raise DatabaseError(f"Failed to get accuracy trends: {str(e)}") from e
|
||||
Reference in New Issue
Block a user