Improve backend
This commit is contained in:
@@ -6,10 +6,20 @@ HTTP endpoints for demand forecasting and prediction operations
|
||||
from .forecasts import router as forecasts_router
|
||||
from .forecasting_operations import router as forecasting_operations_router
|
||||
from .analytics import router as analytics_router
|
||||
from .validation import router as validation_router
|
||||
from .historical_validation import router as historical_validation_router
|
||||
from .webhooks import router as webhooks_router
|
||||
from .performance_monitoring import router as performance_monitoring_router
|
||||
from .retraining import router as retraining_router
|
||||
|
||||
|
||||
# Public re-export surface for the API package: one router per submodule.
__all__ = [
    "forecasts_router",
    "forecasting_operations_router",
    "analytics_router",
    "validation_router",
    "historical_validation_router",
    "webhooks_router",
    "performance_monitoring_router",
    "retraining_router",
]
|
||||
304
services/forecasting/app/api/historical_validation.py
Normal file
304
services/forecasting/app/api/historical_validation.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/historical_validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Historical Validation API - Backfill validation for late-arriving sales data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import date
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Module-level singletons shared by every endpoint in this file.
logger = structlog.get_logger()
route_builder = RouteBuilder('forecasting')  # builds tenant-scoped route paths
router = APIRouter(tags=["historical-validation"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class DetectGapsRequest(BaseModel):
    """Parameters for detecting unvalidated forecast date ranges."""

    # How far back (in days) to scan for validation gaps; 1..365.
    lookback_days: int = Field(
        default=90, ge=1, le=365, description="Days to look back"
    )
|
||||
|
||||
|
||||
class BackfillRequest(BaseModel):
    """Date range for a manually triggered validation backfill."""

    # Inclusive bounds of the period to validate.
    start_date: date = Field(..., description="Start date for backfill")
    end_date: date = Field(..., description="End date for backfill")
|
||||
|
||||
|
||||
class SalesDataUpdateRequest(BaseModel):
    """Payload describing a (possibly late-arriving) sales data update."""

    # Inclusive bounds of the period whose sales data changed.
    start_date: date = Field(..., description="Start date of updated data")
    end_date: date = Field(..., description="End date of updated data")
    # Number of rows touched by the update; must be non-negative.
    records_affected: int = Field(
        ..., ge=0, description="Number of records affected"
    )
    # Origin of the change, e.g. an import pipeline.
    update_source: str = Field(default="import", description="Source of update")
    import_job_id: Optional[str] = Field(
        None, description="Import job ID if applicable"
    )
    # When True, validation is kicked off automatically for the range.
    auto_trigger_validation: bool = Field(
        default=True, description="Auto-trigger validation"
    )
|
||||
|
||||
|
||||
class AutoBackfillRequest(BaseModel):
    """Parameters bounding an automatic gap-detection + backfill pass."""

    # Scan window in days (1..365).
    lookback_days: int = Field(
        default=90, ge=1, le=365, description="Days to look back"
    )
    # Upper bound on how many detected gaps are processed in one call (1..50).
    max_gaps_to_process: int = Field(
        default=10, ge=1, le=50, description="Max gaps to process"
    )
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/detect-gaps"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def detect_validation_gaps(
    request: DetectGapsRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Detect date ranges where forecasts exist but haven't been validated yet.

    Returns a summary dict with the number of gaps found, the lookback
    window used, and each gap as ``{start_date, end_date, days_count}``
    with ISO-formatted dates.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Detecting validation gaps",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)
        gaps = await service.detect_validation_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days
        )

        return {
            "gaps_found": len(gaps),
            "lookback_days": request.lookback_days,
            "gaps": [
                {
                    "start_date": gap["start_date"].isoformat(),
                    "end_date": gap["end_date"].isoformat(),
                    "days_count": gap["days_count"]
                }
                for gap in gaps
            ]
        }

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to detect validation gaps",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to detect validation gaps: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def backfill_validation(
    request: BackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Manually trigger validation backfill for a specific date range.

    Validates forecasts against sales data for historical periods and
    returns the service's backfill result.

    Raises:
        HTTPException: 400 if the date range is inverted, 500 on service failure.
    """
    # Reject inverted ranges up front instead of surfacing a 500 later.
    if request.start_date > request.end_date:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="start_date must be on or before end_date"
        )

    try:
        logger.info(
            "Manual validation backfill requested",
            tenant_id=tenant_id,
            start_date=request.start_date.isoformat(),
            end_date=request.end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)
        result = await service.backfill_validation(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            triggered_by="manual"
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to backfill validation",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to backfill validation: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/auto-backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def auto_backfill_validation_gaps(
    request: AutoBackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Automatically detect and backfill validation gaps.

    Finds all date ranges with missing validations within the lookback
    window and processes up to ``max_gaps_to_process`` of them.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Auto backfill requested",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps=request.max_gaps_to_process,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)
        result = await service.auto_backfill_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps_to_process=request.max_gaps_to_process
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to auto backfill",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to auto backfill: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/register-sales-update"),
    status_code=status.HTTP_201_CREATED
)
@require_user_role(['admin', 'owner', 'member'])
async def register_sales_data_update(
    request: SalesDataUpdateRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Register a sales data update and optionally trigger validation.

    Call this endpoint after importing historical sales data to
    automatically trigger validation for the affected date range.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Registering sales data update",
            tenant_id=tenant_id,
            date_range=f"{request.start_date} to {request.end_date}",
            records_affected=request.records_affected,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)
        result = await service.register_sales_data_update(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            records_affected=request.records_affected,
            update_source=request.update_source,
            import_job_id=request.import_job_id,
            auto_trigger_validation=request.auto_trigger_validation
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to register sales data update",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to register sales data update: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/pending"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_pending_validations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get pending sales data updates awaiting validation.

    Returns the count and serialized list of sales data updates that have
    been registered but not yet validated, capped at ``limit``.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        service = HistoricalValidationService(db)
        pending = await service.get_pending_validations(
            tenant_id=tenant_id,
            limit=limit
        )

        return {
            "pending_count": len(pending),
            "pending_validations": [record.to_dict() for record in pending]
        }

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to get pending validations",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get pending validations: {str(e)}"
        ) from e
|
||||
287
services/forecasting/app/api/performance_monitoring.py
Normal file
287
services/forecasting/app/api/performance_monitoring.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/performance_monitoring.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring API - Track and analyze forecast accuracy over time
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Module-level singletons shared by every endpoint in this file.
logger = structlog.get_logger()
route_builder = RouteBuilder('forecasting')  # builds tenant-scoped route paths
router = APIRouter(tags=["performance-monitoring"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class AccuracySummaryRequest(BaseModel):
    """Parameters for an accuracy summary.

    NOTE(review): the accuracy-summary endpoint takes ``days`` as a query
    parameter; this model appears unused in this module — confirm before
    removing.
    """

    # Analysis window in days (1..365).
    days: int = Field(
        default=30, ge=1, le=365, description="Analysis period in days"
    )
|
||||
|
||||
|
||||
class DegradationAnalysisRequest(BaseModel):
    """Parameters for degradation analysis.

    NOTE(review): the degradation endpoint takes ``lookback_days`` as a
    query parameter; this model appears unused in this module.
    """

    # Analysis window in days (7..365).
    lookback_days: int = Field(
        default=30, ge=7, le=365, description="Days to analyze"
    )
|
||||
|
||||
|
||||
class ModelAgeCheckRequest(BaseModel):
    """Parameters for a model age check.

    NOTE(review): the model-age endpoint takes ``max_age_days`` as a query
    parameter; this model appears unused in this module.
    """

    # Maximum acceptable model age in days (1..90).
    max_age_days: int = Field(
        default=30, ge=1, le=90, description="Max acceptable model age"
    )
|
||||
|
||||
|
||||
class PerformanceReportRequest(BaseModel):
    """Parameters for a comprehensive performance report."""

    # Analysis window in days (1..365).
    days: int = Field(
        default=30, ge=1, le=365, description="Analysis period in days"
    )
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/accuracy-summary"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_summary(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=1, le=365, description="Analysis period in days"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get forecast accuracy summary for recent period.

    Returns overall metrics, validation coverage, and health status as
    produced by the performance-monitoring service.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Getting accuracy summary",
            tenant_id=tenant_id,
            days=days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)
        summary = await service.get_accuracy_summary(
            tenant_id=tenant_id,
            days=days
        )

        return summary

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to get accuracy summary",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get accuracy summary: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/degradation-analysis"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def analyze_performance_degradation(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    lookback_days: int = Query(30, ge=7, le=365, description="Days to analyze"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Detect if forecast performance is degrading over time.

    Compares first half vs second half of period and identifies poor
    performers, as computed by the performance-monitoring service.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Analyzing performance degradation",
            tenant_id=tenant_id,
            lookback_days=lookback_days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)
        analysis = await service.detect_performance_degradation(
            tenant_id=tenant_id,
            lookback_days=lookback_days
        )

        return analysis

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to analyze degradation",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to analyze degradation: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/model-age"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def check_model_age(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    max_age_days: int = Query(30, ge=1, le=90, description="Max acceptable model age"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Check if models are outdated and need retraining.

    Returns models in use and identifies those needing updates, as
    computed by the performance-monitoring service.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Checking model age",
            tenant_id=tenant_id,
            max_age_days=max_age_days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)
        analysis = await service.check_model_age(
            tenant_id=tenant_id,
            max_age_days=max_age_days
        )

        return analysis

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to check model age",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to check model age: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("monitoring/performance-report"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def generate_performance_report(
    request: PerformanceReportRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Generate comprehensive performance report.

    Combines accuracy summary, degradation analysis, and model age check
    with actionable recommendations.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Generating performance report",
            tenant_id=tenant_id,
            days=request.days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)
        report = await service.generate_performance_report(
            tenant_id=tenant_id,
            days=request.days
        )

        return report

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to generate performance report",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate performance report: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/health"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_health_status(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get quick health status for dashboards.

    Returns simplified health metrics for UI display, based on a 7-day
    accuracy summary.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        service = PerformanceMonitoringService(db)

        # Get 7-day summary for quick health check
        summary = await service.get_accuracy_summary(
            tenant_id=tenant_id,
            days=7
        )

        if summary.get("status") == "no_data":
            return {
                "status": "unknown",
                "message": "No recent validation data available",
                "health_status": "unknown"
            }

        # Guard against a missing "average_metrics" key instead of letting a
        # KeyError surface as a 500 (every sibling lookup here uses .get).
        metrics = summary.get("average_metrics") or {}

        return {
            "status": "ok",
            "health_status": summary.get("health_status"),
            "current_mape": metrics.get("mape"),
            "accuracy_percentage": metrics.get("accuracy_percentage"),
            "validation_coverage": summary.get("coverage_percentage"),
            "last_7_days": {
                "validation_runs": summary.get("validation_runs"),
                "forecasts_evaluated": summary.get("total_forecasts_evaluated")
            }
        }

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to get health status",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get health status: {str(e)}"
        ) from e
|
||||
297
services/forecasting/app/api/retraining.py
Normal file
297
services/forecasting/app/api/retraining.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/retraining.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining API - Trigger and manage model retraining based on performance
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.retraining_trigger_service import RetrainingTriggerService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Module-level singletons shared by every endpoint in this file.
logger = structlog.get_logger()
route_builder = RouteBuilder('forecasting')  # builds tenant-scoped route paths
router = APIRouter(tags=["retraining"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class EvaluateRetrainingRequest(BaseModel):
    """Options for a retraining-needs evaluation pass."""

    # When True, retraining is kicked off immediately for poor performers.
    auto_trigger: bool = Field(
        default=False,
        description="Automatically trigger retraining for poor performers"
    )
|
||||
|
||||
|
||||
class TriggerProductRetrainingRequest(BaseModel):
    """Payload for retraining a single product's model."""

    inventory_product_id: UUID = Field(..., description="Product to retrain")
    # Free-text justification recorded with the retraining job.
    reason: str = Field(..., description="Reason for retraining")
    priority: str = Field(
        default="normal",
        description="Priority level: low, normal, high"
    )
|
||||
|
||||
|
||||
class TriggerBulkRetrainingRequest(BaseModel):
    """Payload for retraining several products at once."""

    product_ids: List[UUID] = Field(..., description="List of products to retrain")
    # Free-text justification recorded with the retraining jobs.
    reason: str = Field(
        default="Bulk retraining requested",
        description="Reason for bulk retraining"
    )
|
||||
|
||||
|
||||
class ScheduledRetrainingCheckRequest(BaseModel):
    """Options for an age-based scheduled retraining check."""

    # Models older than this many days (1..90) are flagged for retraining.
    max_model_age_days: int = Field(
        default=30,
        ge=1,
        le=90,
        description="Maximum acceptable model age"
    )
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/evaluate"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def evaluate_retraining_needs(
    request: EvaluateRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Evaluate performance and optionally trigger retraining.

    Analyzes 30-day performance and identifies products needing retraining.
    If auto_trigger=true, automatically triggers retraining for poor performers.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Evaluating retraining needs",
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        result = await service.evaluate_and_trigger_retraining(
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to evaluate retraining needs",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to evaluate retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/trigger-product"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_product_retraining(
    request: TriggerProductRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for a specific product.

    Manually trigger model retraining for a single product.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Triggering product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        # NOTE(review): this calls a private service method; the service
        # should expose a public single-product trigger — confirm and migrate.
        result = await service._trigger_product_retraining(
            tenant_id=tenant_id,
            inventory_product_id=request.inventory_product_id,
            reason=request.reason,
            priority=request.priority
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to trigger product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/trigger-bulk"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_bulk_retraining(
    request: TriggerBulkRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for multiple products.

    Bulk retraining operation for multiple products at once.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Triggering bulk retraining",
            tenant_id=tenant_id,
            product_count=len(request.product_ids),
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        result = await service.trigger_bulk_retraining(
            tenant_id=tenant_id,
            product_ids=request.product_ids,
            reason=request.reason
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to trigger bulk retraining",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger bulk retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("retraining/recommendations"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_retraining_recommendations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get retraining recommendations without triggering.

    Returns recommendations for manual review and decision-making.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Getting retraining recommendations",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        recommendations = await service.get_retraining_recommendations(
            tenant_id=tenant_id
        )

        return recommendations

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to get recommendations",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get recommendations: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/check-scheduled"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def check_scheduled_retraining(
    request: ScheduledRetrainingCheckRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Check for models needing scheduled retraining based on age.

    Identifies models that haven't been updated in max_model_age_days.

    Raises:
        HTTPException: 500 if the underlying service fails.
    """
    try:
        logger.info(
            "Checking scheduled retraining needs",
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)
        result = await service.check_and_trigger_scheduled_retraining(
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days
        )

        return result

    except HTTPException:
        # Preserve deliberate HTTP errors; don't re-wrap them as 500s.
        raise
    except Exception as e:
        logger.error(
            "Failed to check scheduled retraining",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to check scheduled retraining: {str(e)}"
        ) from e
|
||||
346
services/forecasting/app/api/validation.py
Normal file
346
services/forecasting/app/api/validation.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation API - Forecast validation endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["validation"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class ValidationRequest(BaseModel):
    """Request body for triggering forecast validation over a date range."""
    # Period of forecasts to compare against actual sales.
    start_date: datetime = Field(..., description="Start date for validation period")
    end_date: datetime = Field(..., description="End date for validation period")
    # Optional link back to the orchestration run that produced the forecasts.
    orchestration_run_id: Optional[UUID] = Field(None, description="Optional orchestration run ID")
    # Origin of the request; "manual" for direct API callers, "orchestrator"/"scheduled" for jobs.
    triggered_by: str = Field(default="manual", description="Trigger source")
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
    """Summary returned after a validation run completes."""
    validation_run_id: str
    status: str
    # Coverage counts: forecasts considered and how many had matching actuals.
    forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int
    metrics_created: int
    # Aggregate accuracy metrics (MAE/MAPE/RMSE/etc. per the endpoint docs);
    # None when no actuals were available -- TODO confirm in ValidationService.
    overall_metrics: Optional[Dict[str, float]] = None
    total_predicted_demand: Optional[float] = None
    total_actual_demand: Optional[float] = None
    duration_seconds: Optional[float] = None
    message: Optional[str] = None
|
||||
|
||||
|
||||
class ValidationRunResponse(BaseModel):
    """Detailed view of one validation run (built from the run record's to_dict())."""
    # Identity and linkage.
    id: str
    tenant_id: str
    orchestration_run_id: Optional[str]
    # Validated forecast period (string-typed dates).
    validation_start_date: str
    validation_end_date: str
    # Execution timing; completed_at/duration presumably unset until the run
    # finishes -- confirm against the service/model.
    started_at: str
    completed_at: Optional[str]
    duration_seconds: Optional[float]
    status: str
    # Coverage counts.
    total_forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int
    # Overall accuracy metrics across the run.
    overall_mae: Optional[float]
    overall_mape: Optional[float]
    overall_rmse: Optional[float]
    overall_r2_score: Optional[float]
    overall_accuracy_percentage: Optional[float]
    # Demand totals over the validated period.
    total_predicted_demand: float
    total_actual_demand: float
    # Metric breakdowns; structure defined by the validation service.
    metrics_by_product: Optional[Dict[str, Any]]
    metrics_by_location: Optional[Dict[str, Any]]
    metrics_records_created: int
    # Failure details and provenance.
    error_message: Optional[str]
    triggered_by: str
    execution_mode: str
|
||||
|
||||
|
||||
class AccuracyTrendResponse(BaseModel):
    """Accuracy-over-time summary returned by the trends endpoint."""
    # Analysis window requested by the caller (the `days` query parameter).
    period_days: int
    # Number of validation runs found in the window.
    total_runs: int
    # Window-wide averages; None presumably when no runs exist -- TODO confirm.
    average_mape: Optional[float]
    average_accuracy: Optional[float]
    # Trend data points; shape defined by ValidationService.get_accuracy_trends.
    trends: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/validate-date-range"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_date_range(
    validation_request: ValidationRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate forecasts against actual sales for a date range.

    This endpoint:
    - Fetches forecasts for the specified date range
    - Retrieves corresponding actual sales data
    - Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
    - Stores performance metrics in the database
    - Returns validation summary

    Raises:
        HTTPException 400: if end_date precedes start_date.
        HTTPException 500: if the validation service fails.
    """
    # Fix: reject inverted ranges up front with a 400 instead of letting the
    # service fail (or silently validate nothing) and surface as a 500.
    # Raised outside the try block so it is not swallowed by the generic handler.
    if validation_request.end_date < validation_request.start_date:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="end_date must not be earlier than start_date"
        )

    try:
        logger.info(
            "Starting date range validation",
            tenant_id=tenant_id,
            start_date=validation_request.start_date.isoformat(),
            end_date=validation_request.end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        validation_service = ValidationService(db)

        result = await validation_service.validate_date_range(
            tenant_id=tenant_id,
            start_date=validation_request.start_date,
            end_date=validation_request.end_date,
            orchestration_run_id=validation_request.orchestration_run_id,
            triggered_by=validation_request.triggered_by
        )

        logger.info(
            "Date range validation completed",
            tenant_id=tenant_id,
            validation_run_id=result.get("validation_run_id"),
            forecasts_evaluated=result.get("forecasts_evaluated")
        )

        return ValidationResponse(**result)

    except Exception as e:
        logger.error(
            "Failed to validate date range",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate forecasts: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/validate-yesterday"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_yesterday(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate yesterday's forecasts against actual sales.

    Convenience wrapper over date-range validation for the most recent day;
    typically called by the orchestrator as part of the daily workflow.
    """
    try:
        logger.info(
            "Starting yesterday validation",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
        )

        summary = await ValidationService(db).validate_yesterday(
            tenant_id=tenant_id,
            orchestration_run_id=orchestration_run_id,
            triggered_by="manual",
        )

        logger.info(
            "Yesterday validation completed",
            tenant_id=tenant_id,
            validation_run_id=summary.get("validation_run_id"),
            forecasts_evaluated=summary.get("forecasts_evaluated"),
        )

        return ValidationResponse(**summary)

    except Exception as e:
        logger.error(
            "Failed to validate yesterday",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate yesterday's forecasts: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/runs/{validation_run_id}"),
    response_model=ValidationRunResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_run(
    validation_run_id: UUID = Path(..., description="Validation run ID"),
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get details of a specific validation run.

    Returns the full record for one validation execution: summary statistics,
    overall accuracy metrics, per-product/per-location breakdowns, and
    execution metadata. 404 if unknown, 403 if owned by another tenant.
    """
    try:
        run = await ValidationService(db).get_validation_run(validation_run_id)

        # Guard clauses: missing run, then cross-tenant access.
        if not run:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Validation run {validation_run_id} not found"
            )
        if run.tenant_id != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to this validation run"
            )

        return ValidationRunResponse(**run.to_dict())

    except HTTPException:
        # Re-raise deliberate 404/403 untouched.
        raise
    except Exception as e:
        logger.error(
            "Failed to get validation run",
            validation_run_id=validation_run_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation run: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/runs"),
    response_model=List[ValidationRunResponse],
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_runs(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    List validation runs for a tenant with limit/skip pagination.
    """
    try:
        page = await ValidationService(db).get_validation_runs_by_tenant(
            tenant_id=tenant_id,
            limit=limit,
            skip=skip,
        )
        return [ValidationRunResponse(**item.to_dict()) for item in page]

    except Exception as e:
        logger.error(
            "Failed to get validation runs",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation runs: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/trends"),
    response_model=AccuracyTrendResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_trends(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get accuracy trends over time.

    Returns validation accuracy metrics over the last `days` days; useful for
    spotting model-performance degradation or improvement.
    """
    try:
        trend_data = await ValidationService(db).get_accuracy_trends(
            tenant_id=tenant_id,
            days=days,
        )
        return AccuracyTrendResponse(**trend_data)

    except Exception as e:
        logger.error(
            "Failed to get accuracy trends",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get accuracy trends: {str(e)}"
        )
|
||||
174
services/forecasting/app/api/webhooks.py
Normal file
174
services/forecasting/app/api/webhooks.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/webhooks.py
|
||||
# ================================================================
|
||||
"""
|
||||
Webhooks API - Receive events from other services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Header
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import date
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.jobs.sales_data_listener import (
|
||||
handle_sales_import_completion,
|
||||
handle_pos_sync_completion
|
||||
)
|
||||
from shared.routing import RouteBuilder
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["webhooks"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request Schemas
|
||||
# ================================================================
|
||||
|
||||
class SalesImportWebhook(BaseModel):
    """Webhook payload sent by the sales service when a data import completes."""
    tenant_id: UUID = Field(..., description="Tenant ID")
    import_job_id: str = Field(..., description="Import job ID")
    # Date range covered by the imported data; drives the validation backfill.
    start_date: date = Field(..., description="Start date of imported data")
    end_date: date = Field(..., description="End date of imported data")
    records_count: int = Field(..., ge=0, description="Number of records imported")
    # e.g. csv, xlsx, api, pos_sync (per handle_sales_import_completion's docs).
    import_source: str = Field(default="import", description="Source of import")
|
||||
|
||||
|
||||
class POSSyncWebhook(BaseModel):
    """Webhook payload sent by the POS service when a data sync completes."""
    tenant_id: UUID = Field(..., description="Tenant ID")
    sync_log_id: str = Field(..., description="POS sync log ID")
    # Single business date covered by the sync; validation is triggered for it.
    sync_date: date = Field(..., description="Date of synced data")
    records_synced: int = Field(..., ge=0, description="Number of records synced")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    "/webhooks/sales-import-completed",
    status_code=status.HTTP_202_ACCEPTED
)
async def sales_import_completed_webhook(
    payload: SalesImportWebhook,
    x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
    """
    Webhook endpoint for sales data import completion.

    Called by the sales service when a data import completes; triggers
    validation backfill for the imported date range.

    NOTE(review): the signature header is currently accepted but not
    verified -- in production, verify it before trusting the payload.
    """
    try:
        logger.info(
            "Received sales import completion webhook",
            tenant_id=payload.tenant_id,
            import_job_id=payload.import_job_id,
            date_range=f"{payload.start_date} to {payload.end_date}",
        )

        # In production, verify webhook signature here
        # if not verify_webhook_signature(x_webhook_signature, payload):
        #     raise HTTPException(status_code=401, detail="Invalid webhook signature")

        outcome = await handle_sales_import_completion(
            tenant_id=payload.tenant_id,
            import_job_id=payload.import_job_id,
            start_date=payload.start_date,
            end_date=payload.end_date,
            records_count=payload.records_count,
            import_source=payload.import_source,
        )

        return {
            "status": "accepted",
            "message": "Sales import completion event received and processing",
            "result": outcome,
        }

    except Exception as e:
        logger.error(
            "Failed to process sales import webhook",
            payload=payload.dict(),
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process webhook: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    "/webhooks/pos-sync-completed",
    status_code=status.HTTP_202_ACCEPTED
)
async def pos_sync_completed_webhook(
    payload: POSSyncWebhook,
    x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
    """
    Webhook endpoint for POS sync completion.

    Called by the POS service when data synchronization completes; triggers
    validation for the synced date.

    NOTE(review): the signature header is currently accepted but not verified.
    """
    try:
        logger.info(
            "Received POS sync completion webhook",
            tenant_id=payload.tenant_id,
            sync_log_id=payload.sync_log_id,
            sync_date=payload.sync_date.isoformat(),
        )

        # In production, verify webhook signature here
        # if not verify_webhook_signature(x_webhook_signature, payload):
        #     raise HTTPException(status_code=401, detail="Invalid webhook signature")

        outcome = await handle_pos_sync_completion(
            tenant_id=payload.tenant_id,
            sync_log_id=payload.sync_log_id,
            sync_date=payload.sync_date,
            records_synced=payload.records_synced,
        )

        return {
            "status": "accepted",
            "message": "POS sync completion event received and processing",
            "result": outcome,
        }

    except Exception as e:
        logger.error(
            "Failed to process POS sync webhook",
            payload=payload.dict(),
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process webhook: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    "/webhooks/health",
    status_code=status.HTTP_200_OK
)
async def webhook_health_check():
    """Health check endpoint for the webhook receiver; lists handled routes."""
    handled_routes = [
        "/webhooks/sales-import-completed",
        "/webhooks/pos-sync-completed",
    ]
    return {
        "status": "healthy",
        "service": "forecasting-webhooks",
        "endpoints": handled_routes,
    }
|
||||
29
services/forecasting/app/jobs/__init__.py
Normal file
29
services/forecasting/app/jobs/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Forecasting Service Jobs Package
|
||||
Scheduled and background jobs for the forecasting service
|
||||
"""
|
||||
|
||||
from .daily_validation import daily_validation_job, validate_date_range_job
|
||||
from .sales_data_listener import (
|
||||
handle_sales_import_completion,
|
||||
handle_pos_sync_completion,
|
||||
process_pending_validations
|
||||
)
|
||||
from .auto_backfill_job import (
|
||||
auto_backfill_all_tenants,
|
||||
process_all_pending_validations,
|
||||
daily_validation_maintenance_job,
|
||||
run_validation_maintenance_for_tenant
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"daily_validation_job",
|
||||
"validate_date_range_job",
|
||||
"handle_sales_import_completion",
|
||||
"handle_pos_sync_completion",
|
||||
"process_pending_validations",
|
||||
"auto_backfill_all_tenants",
|
||||
"process_all_pending_validations",
|
||||
"daily_validation_maintenance_job",
|
||||
"run_validation_maintenance_for_tenant",
|
||||
]
|
||||
275
services/forecasting/app/jobs/auto_backfill_job.py
Normal file
275
services/forecasting/app/jobs/auto_backfill_job.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/auto_backfill_job.py
|
||||
# ================================================================
|
||||
"""
|
||||
Automated Backfill Job
|
||||
|
||||
Scheduled job to automatically detect and backfill validation gaps.
|
||||
Can be run daily or weekly to ensure all historical forecasts are validated.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from app.core.database import database_manager
|
||||
from app.jobs.sales_data_listener import process_pending_validations
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def auto_backfill_all_tenants(
    tenant_ids: List[uuid.UUID],
    lookback_days: int = 90,
    max_gaps_per_tenant: int = 5
) -> Dict[str, Any]:
    """
    Run auto backfill for multiple tenants.

    Args:
        tenant_ids: List of tenant IDs to process
        lookback_days: How far back to check for gaps
        max_gaps_per_tenant: Maximum number of gaps to process per tenant

    Returns:
        Summary of backfill operations across all tenants
    """
    try:
        logger.info(
            "Starting auto backfill for all tenants",
            tenant_count=len(tenant_ids),
            lookback_days=lookback_days,
        )

        per_tenant = []
        gaps_found_total = 0
        gaps_processed_total = 0
        validations_done_total = 0

        # Each tenant gets its own session and its own failure isolation:
        # one tenant failing must not abort the rest of the batch.
        for tid in tenant_ids:
            try:
                async with database_manager.get_session() as session:
                    outcome = await HistoricalValidationService(session).auto_backfill_gaps(
                        tenant_id=tid,
                        lookback_days=lookback_days,
                        max_gaps_to_process=max_gaps_per_tenant,
                    )

                    per_tenant.append({
                        "tenant_id": str(tid),
                        "status": "success",
                        **outcome,
                    })

                    gaps_found_total += outcome.get("gaps_found", 0)
                    gaps_processed_total += outcome.get("gaps_processed", 0)
                    validations_done_total += outcome.get("validations_completed", 0)

            except Exception as e:
                logger.error(
                    "Failed to auto backfill for tenant",
                    tenant_id=tid,
                    error=str(e),
                )
                per_tenant.append({
                    "tenant_id": str(tid),
                    "status": "failed",
                    "error": str(e),
                })

        logger.info(
            "Auto backfill completed for all tenants",
            tenant_count=len(tenant_ids),
            total_gaps_found=gaps_found_total,
            total_gaps_processed=gaps_processed_total,
            total_successful=validations_done_total,
        )

        return {
            "status": "completed",
            "tenants_processed": len(tenant_ids),
            "total_gaps_found": gaps_found_total,
            "total_gaps_processed": gaps_processed_total,
            "total_validations_completed": validations_done_total,
            "results": per_tenant,
        }

    except Exception as e:
        logger.error(
            "Auto backfill job failed",
            error=str(e),
        )
        return {
            "status": "failed",
            "error": str(e),
        }
|
||||
|
||||
|
||||
async def process_all_pending_validations(
    tenant_ids: List[uuid.UUID],
    max_per_tenant: int = 10
) -> Dict[str, Any]:
    """
    Process all pending validations for multiple tenants.

    Args:
        tenant_ids: List of tenant IDs to process
        max_per_tenant: Maximum pending validations to process per tenant

    Returns:
        Summary of processing results
    """
    try:
        logger.info(
            "Processing pending validations for all tenants",
            tenant_count=len(tenant_ids),
        )

        per_tenant = []
        pending_total = 0
        processed_total = 0
        successful_total = 0

        # Per-tenant failure isolation: a failing tenant is recorded and skipped.
        for tid in tenant_ids:
            try:
                outcome = await process_pending_validations(
                    tenant_id=tid,
                    max_to_process=max_per_tenant,
                )

                per_tenant.append({
                    "tenant_id": str(tid),
                    **outcome,
                })

                pending_total += outcome.get("pending_count", 0)
                processed_total += outcome.get("processed", 0)
                successful_total += outcome.get("successful", 0)

            except Exception as e:
                logger.error(
                    "Failed to process pending validations for tenant",
                    tenant_id=tid,
                    error=str(e),
                )
                per_tenant.append({
                    "tenant_id": str(tid),
                    "status": "failed",
                    "error": str(e),
                })

        logger.info(
            "Pending validations processed for all tenants",
            tenant_count=len(tenant_ids),
            total_pending=pending_total,
            total_processed=processed_total,
            total_successful=successful_total,
        )

        return {
            "status": "completed",
            "tenants_processed": len(tenant_ids),
            "total_pending": pending_total,
            "total_processed": processed_total,
            "total_successful": successful_total,
            "results": per_tenant,
        }

    except Exception as e:
        logger.error(
            "Failed to process all pending validations",
            error=str(e),
        )
        return {
            "status": "failed",
            "error": str(e),
        }
|
||||
|
||||
|
||||
async def daily_validation_maintenance_job(
    tenant_ids: List[uuid.UUID]
) -> Dict[str, Any]:
    """
    Daily validation maintenance job.

    Combines gap detection/backfill and pending validation processing.
    Recommended to run once daily (e.g. 6:00 AM after orchestrator completes).

    Args:
        tenant_ids: List of tenant IDs to process

    Returns:
        Summary of all maintenance operations
    """
    try:
        logger.info(
            "Starting daily validation maintenance",
            tenant_count=len(tenant_ids),
            timestamp=datetime.now(timezone.utc).isoformat(),
        )

        # Phase 1: retry previously pending/failed validations.
        pending_result = await process_all_pending_validations(
            tenant_ids=tenant_ids,
            max_per_tenant=10,
        )

        # Phase 2: detect and backfill validation gaps.
        backfill_result = await auto_backfill_all_tenants(
            tenant_ids=tenant_ids,
            lookback_days=90,
            max_gaps_per_tenant=5,
        )

        pending_processed = pending_result.get("total_processed", 0)
        gaps_backfilled = backfill_result.get("total_validations_completed", 0)

        logger.info(
            "Daily validation maintenance completed",
            pending_validations_processed=pending_processed,
            gaps_backfilled=gaps_backfilled,
        )

        return {
            "status": "completed",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "tenants_processed": len(tenant_ids),
            "pending_validations": pending_result,
            "gap_backfill": backfill_result,
            "summary": {
                "total_pending_processed": pending_processed,
                "total_gaps_backfilled": gaps_backfilled,
                "total_validations": pending_processed + gaps_backfilled,
            },
        }

    except Exception as e:
        logger.error(
            "Daily validation maintenance failed",
            error=str(e),
        )
        return {
            "status": "failed",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "error": str(e),
        }
|
||||
|
||||
|
||||
# Convenience function for single tenant
|
||||
async def run_validation_maintenance_for_tenant(
    tenant_id: uuid.UUID
) -> Dict[str, Any]:
    """
    Run validation maintenance for a single tenant.

    Thin convenience wrapper around daily_validation_maintenance_job.

    Args:
        tenant_id: Tenant identifier

    Returns:
        Maintenance results
    """
    return await daily_validation_maintenance_job(tenant_ids=[tenant_id])
|
||||
147
services/forecasting/app/jobs/daily_validation.py
Normal file
147
services/forecasting/app/jobs/daily_validation.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/daily_validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Daily Validation Job
|
||||
|
||||
Scheduled job to validate previous day's forecasts against actual sales.
|
||||
This job is called by the orchestrator as part of the daily workflow.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.validation_service import ValidationService
|
||||
from app.core.database import database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def daily_validation_job(
    tenant_id: uuid.UUID,
    orchestration_run_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """
    Validate yesterday's forecasts against actual sales.

    Designed to be called by the orchestrator as part of the daily workflow
    (Step 5: validate_previous_forecasts).

    Args:
        tenant_id: Tenant identifier
        orchestration_run_id: Optional orchestration run ID for tracking

    Returns:
        Dictionary with validation results; on failure, a dict with
        status "failed" and the error (errors are reported, not raised).
    """
    async with database_manager.get_session() as session:
        try:
            logger.info(
                "Starting daily validation job",
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
            )

            outcome = await ValidationService(session).validate_yesterday(
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                triggered_by="orchestrator",
            )

            logger.info(
                "Daily validation job completed",
                tenant_id=tenant_id,
                validation_run_id=outcome.get("validation_run_id"),
                forecasts_evaluated=outcome.get("forecasts_evaluated"),
                forecasts_with_actuals=outcome.get("forecasts_with_actuals"),
                overall_mape=outcome.get("overall_metrics", {}).get("mape"),
            )

            return outcome

        except Exception as e:
            logger.error(
                "Daily validation job failed",
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                error=str(e),
                error_type=type(e).__name__,
            )
            return {
                "status": "failed",
                "error": str(e),
                "tenant_id": str(tenant_id),
                "orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None,
            }
|
||||
|
||||
|
||||
async def validate_date_range_job(
    tenant_id: uuid.UUID,
    start_date: datetime,
    end_date: datetime,
    orchestration_run_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """
    Validate forecasts for a specific date range.

    Useful for backfilling validation metrics when historical data is uploaded.

    Args:
        tenant_id: Tenant identifier
        start_date: Start of validation period
        end_date: End of validation period
        orchestration_run_id: Optional orchestration run ID for tracking

    Returns:
        Dictionary with validation results; on failure, a dict with
        status "failed" and the error (errors are reported, not raised).
    """
    async with database_manager.get_session() as session:
        try:
            logger.info(
                "Starting date range validation job",
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                orchestration_run_id=orchestration_run_id,
            )

            outcome = await ValidationService(session).validate_date_range(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                orchestration_run_id=orchestration_run_id,
                triggered_by="scheduled",
            )

            logger.info(
                "Date range validation job completed",
                tenant_id=tenant_id,
                validation_run_id=outcome.get("validation_run_id"),
                forecasts_evaluated=outcome.get("forecasts_evaluated"),
                forecasts_with_actuals=outcome.get("forecasts_with_actuals"),
            )

            return outcome

        except Exception as e:
            logger.error(
                "Date range validation job failed",
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                error=str(e),
                error_type=type(e).__name__,
            )
            return {
                "status": "failed",
                "error": str(e),
                "tenant_id": str(tenant_id),
                "orchestration_run_id": str(orchestration_run_id) if orchestration_run_id else None,
            }
|
||||
276
services/forecasting/app/jobs/sales_data_listener.py
Normal file
276
services/forecasting/app/jobs/sales_data_listener.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/jobs/sales_data_listener.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Data Listener
|
||||
|
||||
Listens for sales data import completions and triggers validation backfill.
|
||||
Can be called via webhook, message queue, or direct API call from sales service.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, date
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from app.core.database import database_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def handle_sales_import_completion(
    tenant_id: uuid.UUID,
    import_job_id: str,
    start_date: date,
    end_date: date,
    records_count: int,
    import_source: str = "import"
) -> Dict[str, Any]:
    """
    React to a completed sales data import.

    Called by the sales service (webhook, queue message, or direct call)
    once an import finishes.  Registers the imported date range as a
    sales-data update and lets the historical validation service trigger
    a validation backfill for it.

    Args:
        tenant_id: Tenant identifier
        import_job_id: Sales import job ID
        start_date: Start date of imported data
        end_date: End date of imported data
        records_count: Number of records imported
        import_source: Source of import (csv, xlsx, api, pos_sync)

    Returns:
        Dictionary with registration and validation results; on error a
        dict with status "failed" and the error message (never raises).
    """
    async with database_manager.get_session() as db:
        try:
            logger.info(
                "Handling sales import completion",
                tenant_id=tenant_id,
                import_job_id=import_job_id,
                date_range=f"{start_date} to {end_date}",
                records_count=records_count
            )

            validation_service = HistoricalValidationService(db)

            # Record the update; auto_trigger_validation kicks off the
            # backfill for the affected range inside the service.
            registration = await validation_service.register_sales_data_update(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                records_affected=records_count,
                update_source=import_source,
                import_job_id=import_job_id,
                auto_trigger_validation=True
            )

            logger.info(
                "Sales import completion handled",
                tenant_id=tenant_id,
                import_job_id=import_job_id,
                update_id=registration.get("update_id"),
                validation_triggered=registration.get("validation_triggered")
            )

            response = {
                "status": "success",
                "tenant_id": str(tenant_id),
                "import_job_id": import_job_id,
            }
            response.update(registration)
            return response

        except Exception as exc:
            logger.error(
                "Failed to handle sales import completion",
                tenant_id=tenant_id,
                import_job_id=import_job_id,
                error=str(exc),
                error_type=type(exc).__name__
            )
            return {
                "status": "failed",
                "error": str(exc),
                "tenant_id": str(tenant_id),
                "import_job_id": import_job_id
            }
|
||||
|
||||
|
||||
async def handle_pos_sync_completion(
    tenant_id: uuid.UUID,
    sync_log_id: str,
    sync_date: date,
    records_synced: int
) -> Dict[str, Any]:
    """
    React to a completed POS synchronization.

    Registers the synced day as a sales-data update and triggers
    validation for exactly that date.

    Args:
        tenant_id: Tenant identifier
        sync_log_id: POS sync log ID
        sync_date: Date of synced data
        records_synced: Number of records synced

    Returns:
        Dictionary with registration and validation results; on error a
        dict with status "failed" and the error message (never raises).
    """
    async with database_manager.get_session() as db:
        try:
            logger.info(
                "Handling POS sync completion",
                tenant_id=tenant_id,
                sync_log_id=sync_log_id,
                sync_date=sync_date.isoformat(),
                records_synced=records_synced
            )

            validation_service = HistoricalValidationService(db)

            # A POS sync affects a single day, so the validation range
            # starts and ends on the sync date.
            registration = await validation_service.register_sales_data_update(
                tenant_id=tenant_id,
                start_date=sync_date,
                end_date=sync_date,
                records_affected=records_synced,
                update_source="pos_sync",
                import_job_id=sync_log_id,
                auto_trigger_validation=True
            )

            logger.info(
                "POS sync completion handled",
                tenant_id=tenant_id,
                sync_log_id=sync_log_id,
                update_id=registration.get("update_id")
            )

            response = {
                "status": "success",
                "tenant_id": str(tenant_id),
                "sync_log_id": sync_log_id,
            }
            response.update(registration)
            return response

        except Exception as exc:
            logger.error(
                "Failed to handle POS sync completion",
                tenant_id=tenant_id,
                sync_log_id=sync_log_id,
                error=str(exc)
            )
            return {
                "status": "failed",
                "error": str(exc),
                "tenant_id": str(tenant_id),
                "sync_log_id": sync_log_id
            }
|
||||
|
||||
|
||||
async def process_pending_validations(
    tenant_id: Optional[uuid.UUID] = None,
    max_to_process: int = 10
) -> Dict[str, Any]:
    """
    Process pending validation requests.

    Can be run as a scheduled job to handle any pending validations
    that failed to trigger automatically.  Each pending record is
    processed independently: one failure does not stop the batch.

    Args:
        tenant_id: Optional tenant ID to filter (processing all tenants
            is not implemented yet, so this is effectively required)
        max_to_process: Maximum number of pending validations to process

    Returns:
        Summary of processing results ("skipped" when tenant_id is
        missing, "failed" with the error on unexpected failure).
    """
    logger.info(
        "Processing pending validations",
        tenant_id=tenant_id,
        max_to_process=max_to_process
    )

    # Guard clause: bail out before acquiring a DB session we would not
    # use.  (Cross-tenant processing would need a
    # get_all_pending_validations helper that does not exist yet.)
    if tenant_id is None:
        logger.warning("Processing all tenants not implemented, tenant_id required")
        return {
            "status": "skipped",
            "message": "tenant_id required"
        }

    async with database_manager.get_session() as db:
        try:
            service = HistoricalValidationService(db)

            pending = await service.get_pending_validations(
                tenant_id=tenant_id,
                limit=max_to_process
            )

            if not pending:
                logger.info("No pending validations found")
                return {
                    "status": "success",
                    "pending_count": 0,
                    "processed": 0
                }

            results = []
            for update_record in pending:
                try:
                    result = await service.backfill_validation(
                        tenant_id=update_record.tenant_id,
                        start_date=update_record.update_date_start,
                        end_date=update_record.update_date_end,
                        triggered_by="pending_processor",
                        sales_data_update_id=update_record.id
                    )
                    results.append({
                        "update_id": str(update_record.id),
                        "status": "success",
                        "validation_run_id": result.get("validation_run_id")
                    })
                except Exception as e:
                    # Per-record failure: record it and continue the batch.
                    logger.error(
                        "Failed to process pending validation",
                        update_id=update_record.id,
                        error=str(e)
                    )
                    results.append({
                        "update_id": str(update_record.id),
                        "status": "failed",
                        "error": str(e)
                    })

            successful = sum(1 for r in results if r["status"] == "success")

            logger.info(
                "Pending validations processed",
                pending_count=len(pending),
                processed=len(results),
                successful=successful
            )

            return {
                "status": "success",
                "pending_count": len(pending),
                "processed": len(results),
                "successful": successful,
                "failed": len(results) - successful,
                "results": results
            }

        except Exception as e:
            logger.error(
                "Failed to process pending validations",
                error=str(e)
            )
            return {
                "status": "failed",
                "error": str(e)
            }
|
||||
@@ -15,13 +15,13 @@ from app.services.forecasting_alert_service import ForecastingAlertService
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
# Import API routers
|
||||
from app.api import forecasts, forecasting_operations, analytics, scenario_operations, internal_demo, audit, ml_insights
|
||||
from app.api import forecasts, forecasting_operations, analytics, scenario_operations, internal_demo, audit, ml_insights, validation, historical_validation, webhooks, performance_monitoring, retraining
|
||||
|
||||
|
||||
class ForecastingService(StandardFastAPIService):
|
||||
"""Forecasting Service with standardized setup"""
|
||||
|
||||
expected_migration_version = "00001"
|
||||
expected_migration_version = "00003"
|
||||
|
||||
async def on_startup(self, app):
|
||||
"""Custom startup logic including migration verification"""
|
||||
@@ -45,7 +45,7 @@ class ForecastingService(StandardFastAPIService):
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
forecasting_expected_tables = [
|
||||
'forecasts', 'prediction_batches', 'model_performance_metrics', 'prediction_cache'
|
||||
'forecasts', 'prediction_batches', 'model_performance_metrics', 'prediction_cache', 'validation_runs', 'sales_data_updates'
|
||||
]
|
||||
|
||||
self.alert_service = None
|
||||
@@ -171,6 +171,11 @@ service.add_router(analytics.router)
|
||||
service.add_router(scenario_operations.router)
|
||||
service.add_router(internal_demo.router)
|
||||
service.add_router(ml_insights.router) # ML insights endpoint
|
||||
service.add_router(validation.router) # Validation endpoint
|
||||
service.add_router(historical_validation.router) # Historical validation endpoint
|
||||
service.add_router(webhooks.router) # Webhooks endpoint
|
||||
service.add_router(performance_monitoring.router) # Performance monitoring endpoint
|
||||
service.add_router(retraining.router) # Retraining endpoint
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -14,6 +14,8 @@ AuditLog = create_audit_log_model(Base)
|
||||
# Import all models to register them with the Base metadata
|
||||
from .forecasts import Forecast, PredictionBatch
|
||||
from .predictions import ModelPerformanceMetric, PredictionCache
|
||||
from .validation_run import ValidationRun
|
||||
from .sales_data_update import SalesDataUpdate
|
||||
|
||||
# List all models for easier access
|
||||
__all__ = [
|
||||
@@ -21,5 +23,7 @@ __all__ = [
|
||||
"PredictionBatch",
|
||||
"ModelPerformanceMetric",
|
||||
"PredictionCache",
|
||||
"ValidationRun",
|
||||
"SalesDataUpdate",
|
||||
"AuditLog",
|
||||
]
|
||||
78
services/forecasting/app/models/sales_data_update.py
Normal file
78
services/forecasting/app/models/sales_data_update.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/sales_data_update.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Data Update Tracking Model
|
||||
|
||||
Tracks when sales data is added or updated for past dates,
|
||||
enabling automated historical validation backfill.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Index, Date
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class SalesDataUpdate(Base):
    """Track sales data updates for historical validation

    One row per batch of sales records added or changed for past dates
    (file import, manual entry, or POS sync).  The validation_* columns
    record whether the forecast-accuracy backfill for the affected date
    range has been run, and with what outcome.
    """
    __tablename__ = "sales_data_updates"

    __table_args__ = (
        # Per-tenant listing filtered by validation state, ordered by creation time
        Index('ix_sales_updates_tenant_status', 'tenant_id', 'validation_status', 'created_at'),
        # Queries against the affected date range for a tenant
        Index('ix_sales_updates_date_range', 'tenant_id', 'update_date_start', 'update_date_end'),
        # Cross-tenant scans by validation status
        Index('ix_sales_updates_validation_status', 'validation_status'),
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Date range of sales data that was added/updated
    update_date_start = Column(Date, nullable=False, index=True)
    update_date_end = Column(Date, nullable=False, index=True)

    # Update metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    update_source = Column(String(100), nullable=True)  # import, manual, pos_sync
    records_affected = Column(Integer, default=0)

    # Validation tracking
    validation_status = Column(String(50), default="pending")  # pending, processing, completed, failed
    validation_run_id = Column(UUID(as_uuid=True), nullable=True)
    validated_at = Column(DateTime(timezone=True), nullable=True)
    validation_error = Column(String(500), nullable=True)

    # Determines if this update should trigger validation
    requires_validation = Column(Boolean, default=True)

    # Additional context
    import_job_id = Column(String(255), nullable=True)  # Link to sales import job if applicable
    notes = Column(String(500), nullable=True)

    def __repr__(self):
        return (
            f"<SalesDataUpdate(id={self.id}, tenant_id={self.tenant_id}, "
            f"date_range={self.update_date_start} to {self.update_date_end}, "
            f"status={self.validation_status})>"
        )

    def to_dict(self):
        """Convert to dictionary for API responses"""
        # UUID and date/datetime fields are stringified so the dict is
        # directly JSON-serializable.
        return {
            'id': str(self.id),
            'tenant_id': str(self.tenant_id),
            'update_date_start': self.update_date_start.isoformat() if self.update_date_start else None,
            'update_date_end': self.update_date_end.isoformat() if self.update_date_end else None,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'update_source': self.update_source,
            'records_affected': self.records_affected,
            'validation_status': self.validation_status,
            'validation_run_id': str(self.validation_run_id) if self.validation_run_id else None,
            'validated_at': self.validated_at.isoformat() if self.validated_at else None,
            'validation_error': self.validation_error,
            'requires_validation': self.requires_validation,
            'import_job_id': self.import_job_id,
            'notes': self.notes
        }
|
||||
110
services/forecasting/app/models/validation_run.py
Normal file
110
services/forecasting/app/models/validation_run.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/models/validation_run.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation run models for tracking forecast validation executions
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Integer, Float, DateTime, Text, JSON, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class ValidationRun(Base):
    """Track forecast validation execution runs

    One row per validation execution (manual, orchestrated, scheduled, or
    backfill).  Stores the evaluated period, lifecycle timestamps,
    aggregate accuracy metrics, and per-product / per-location breakdowns.
    """
    __tablename__ = "validation_runs"

    __table_args__ = (
        # Per-tenant run history ordered by start time
        Index('ix_validation_runs_tenant_created', 'tenant_id', 'started_at'),
        # Scans by lifecycle status (pending/running/completed/failed)
        Index('ix_validation_runs_status', 'status', 'started_at'),
        # Lookup of runs triggered by a given orchestration run
        Index('ix_validation_runs_orchestration', 'orchestration_run_id'),
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)

    # Link to orchestration run (if triggered by orchestrator)
    orchestration_run_id = Column(UUID(as_uuid=True), nullable=True)

    # Validation period (inclusive datetime bounds of the evaluated window)
    validation_start_date = Column(DateTime(timezone=True), nullable=False)
    validation_end_date = Column(DateTime(timezone=True), nullable=False)

    # Execution metadata
    started_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    completed_at = Column(DateTime(timezone=True), nullable=True)
    duration_seconds = Column(Float, nullable=True)

    # Status and results
    status = Column(String(50), default="pending")  # pending, running, completed, failed

    # Validation statistics
    total_forecasts_evaluated = Column(Integer, default=0)
    forecasts_with_actuals = Column(Integer, default=0)
    forecasts_without_actuals = Column(Integer, default=0)

    # Accuracy metrics summary (across all validated forecasts)
    overall_mae = Column(Float, nullable=True)
    overall_mape = Column(Float, nullable=True)
    overall_rmse = Column(Float, nullable=True)
    overall_r2_score = Column(Float, nullable=True)
    overall_accuracy_percentage = Column(Float, nullable=True)

    # Additional statistics
    total_predicted_demand = Column(Float, default=0.0)
    total_actual_demand = Column(Float, default=0.0)

    # Breakdown by product/location (JSON)
    metrics_by_product = Column(JSON, nullable=True)  # {product_id: {mae, mape, ...}}
    metrics_by_location = Column(JSON, nullable=True)  # {location: {mae, mape, ...}}

    # Performance metrics created count
    metrics_records_created = Column(Integer, default=0)

    # Error tracking
    error_message = Column(Text, nullable=True)
    error_details = Column(JSON, nullable=True)

    # Execution context
    triggered_by = Column(String(100), default="manual")  # manual, orchestrator, scheduled
    execution_mode = Column(String(50), default="batch")  # batch, single_day, real_time

    def __repr__(self):
        return (
            f"<ValidationRun(id={self.id}, tenant_id={self.tenant_id}, "
            f"status={self.status}, forecasts_evaluated={self.total_forecasts_evaluated})>"
        )

    def to_dict(self):
        """Convert to dictionary for API responses"""
        # UUID and datetime fields are stringified so the payload is
        # directly JSON-serializable.
        return {
            'id': str(self.id),
            'tenant_id': str(self.tenant_id),
            'orchestration_run_id': str(self.orchestration_run_id) if self.orchestration_run_id else None,
            'validation_start_date': self.validation_start_date.isoformat() if self.validation_start_date else None,
            'validation_end_date': self.validation_end_date.isoformat() if self.validation_end_date else None,
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'duration_seconds': self.duration_seconds,
            'status': self.status,
            'total_forecasts_evaluated': self.total_forecasts_evaluated,
            'forecasts_with_actuals': self.forecasts_with_actuals,
            'forecasts_without_actuals': self.forecasts_without_actuals,
            'overall_mae': self.overall_mae,
            'overall_mape': self.overall_mape,
            'overall_rmse': self.overall_rmse,
            'overall_r2_score': self.overall_r2_score,
            'overall_accuracy_percentage': self.overall_accuracy_percentage,
            'total_predicted_demand': self.total_predicted_demand,
            'total_actual_demand': self.total_actual_demand,
            'metrics_by_product': self.metrics_by_product,
            'metrics_by_location': self.metrics_by_location,
            'metrics_records_created': self.metrics_records_created,
            'error_message': self.error_message,
            'error_details': self.error_details,
            'triggered_by': self.triggered_by,
            'execution_mode': self.execution_mode,
        }
|
||||
@@ -167,4 +167,105 @@ class PerformanceMetricRepository(ForecastingBaseRepository):
|
||||
|
||||
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
    """Delete performance metrics older than *days_old* days.

    Args:
        days_old: Age threshold in days (default 180).

    Returns:
        Number of records removed.
    """
    # Delegates to the shared base-repository cleanup helper.
    # (A duplicated, unreachable return statement was removed.)
    return await self.cleanup_old_records(days_old=days_old)
|
||||
|
||||
async def bulk_create_metrics(self, metrics: List[ModelPerformanceMetric]) -> int:
    """
    Bulk insert performance metrics for validation

    Args:
        metrics: List of ModelPerformanceMetric objects to insert

    Returns:
        Number of metrics created (0 for an empty list)

    Raises:
        DatabaseError: if adding or flushing the metrics fails
    """
    # Fast path: nothing to insert, avoid a pointless flush.
    if not metrics:
        return 0

    try:
        self.session.add_all(metrics)
        # Flush (not commit) so the rows are written while the caller
        # retains control of the transaction boundary.
        await self.session.flush()

        logger.info(
            "Bulk created performance metrics",
            count=len(metrics)
        )

        return len(metrics)

    except Exception as e:
        logger.error(
            "Failed to bulk create performance metrics",
            count=len(metrics),
            error=str(e)
        )
        # Chain the original exception so the root cause is preserved.
        raise DatabaseError(f"Failed to bulk create metrics: {str(e)}") from e
|
||||
|
||||
async def get_metrics_by_date_range(
    self,
    tenant_id: str,
    start_date: datetime,
    end_date: datetime,
    inventory_product_id: Optional[str] = None
) -> List[ModelPerformanceMetric]:
    """
    Get performance metrics for a date range

    Args:
        tenant_id: Tenant identifier
        start_date: Start of date range (inclusive)
        end_date: End of date range (inclusive)
        inventory_product_id: Optional product filter

    Returns:
        List of performance metrics, newest evaluation_date first

    Raises:
        DatabaseError: if the query fails
    """
    try:
        # Parameterized raw SQL for the date-range comparison; all values
        # are bound parameters, never interpolated into the string.
        # (A dead `filters` dict that was built but never used has been
        # removed.)
        query_text = """
            SELECT *
            FROM model_performance_metrics
            WHERE tenant_id = :tenant_id
            AND evaluation_date >= :start_date
            AND evaluation_date <= :end_date
        """

        params = {
            "tenant_id": tenant_id,
            "start_date": start_date,
            "end_date": end_date
        }

        if inventory_product_id:
            query_text += " AND inventory_product_id = :inventory_product_id"
            params["inventory_product_id"] = inventory_product_id

        query_text += " ORDER BY evaluation_date DESC"

        result = await self.session.execute(text(query_text), params)
        rows = result.fetchall()

        # Hydrate ORM objects column-by-column from the raw rows.
        metrics = []
        for row in rows:
            metric = ModelPerformanceMetric()
            for column in row._mapping.keys():
                setattr(metric, column, row._mapping[column])
            metrics.append(metric)

        return metrics

    except Exception as e:
        logger.error(
            "Failed to get metrics by date range",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Chain the original exception so the root cause is preserved.
        raise DatabaseError(f"Failed to get metrics: {str(e)}") from e
|
||||
@@ -0,0 +1,480 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/historical_validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Historical Validation Service
|
||||
|
||||
Handles validation backfill when historical sales data is uploaded late.
|
||||
Detects gaps in validation coverage and automatically triggers validation
|
||||
for periods where forecasts exist but haven't been validated yet.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date, or_
|
||||
from datetime import datetime, timedelta, timezone, date
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.sales_data_update import SalesDataUpdate
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HistoricalValidationService:
|
||||
"""Service for backfilling historical validation when sales data arrives late"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
    """Initialize with an async DB session shared by all queries."""
    self.db = db_session
    # Delegate the actual forecast-vs-actuals evaluation to the core
    # validation service, reusing the same session (and thus the same
    # transaction) as this service.
    self.validation_service = ValidationService(db_session)
|
||||
|
||||
async def detect_validation_gaps(
    self,
    tenant_id: uuid.UUID,
    lookback_days: int = 90
) -> List[Dict[str, Any]]:
    """
    Detect date ranges where forecasts exist but haven't been validated

    Args:
        tenant_id: Tenant identifier
        lookback_days: How far back to check (default 90 days)

    Returns:
        List of gap periods, each a dict with "start_date", "end_date"
        (inclusive dates) and "days_count"

    Raises:
        DatabaseError: if the underlying queries fail
    """
    try:
        end_date = datetime.now(timezone.utc)
        start_date = end_date - timedelta(days=lookback_days)

        logger.info(
            "Detecting validation gaps",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat()
        )

        # All distinct calendar dates that have at least one forecast
        forecast_query = select(
            func.cast(Forecast.forecast_date, Date).label('forecast_date')
        ).where(
            and_(
                Forecast.tenant_id == tenant_id,
                Forecast.forecast_date >= start_date,
                Forecast.forecast_date <= end_date
            )
        ).group_by(
            func.cast(Forecast.forecast_date, Date)
        )

        forecast_result = await self.db.execute(forecast_query)
        forecast_dates = {row.forecast_date for row in forecast_result.fetchall()}

        if not forecast_dates:
            logger.info("No forecasts found in lookback period", tenant_id=tenant_id)
            return []

        # Completed validation runs in the window.  A run may cover several
        # days, so expand each run's full [start, end] range into the
        # validated set.  (Previously only the run's start date was
        # counted, so the remaining days of an already-validated range
        # were falsely re-reported as gaps.)
        validation_query = select(
            func.cast(ValidationRun.validation_start_date, Date).label('validated_start'),
            func.cast(ValidationRun.validation_end_date, Date).label('validated_end')
        ).where(
            and_(
                ValidationRun.tenant_id == tenant_id,
                ValidationRun.status == "completed",
                ValidationRun.validation_start_date >= start_date,
                ValidationRun.validation_end_date <= end_date
            )
        )

        validation_result = await self.db.execute(validation_query)
        validated_dates = set()
        for row in validation_result.fetchall():
            day = row.validated_start
            while day <= row.validated_end:
                validated_dates.add(day)
                day += timedelta(days=1)

        # Gaps = dates with forecasts but no completed validation
        gap_dates = sorted(forecast_dates - validated_dates)

        if not gap_dates:
            logger.info("No validation gaps found", tenant_id=tenant_id)
            return []

        # Group consecutive days into contiguous [start, end] ranges
        gaps = []
        range_start = gap_dates[0]
        range_end = gap_dates[0]

        for current in gap_dates[1:]:
            if (current - range_end).days == 1:
                # Consecutive date, extend the current range
                range_end = current
            else:
                gaps.append({
                    "start_date": range_start,
                    "end_date": range_end,
                    "days_count": (range_end - range_start).days + 1
                })
                range_start = current
                range_end = current

        # Don't forget the last open range
        gaps.append({
            "start_date": range_start,
            "end_date": range_end,
            "days_count": (range_end - range_start).days + 1
        })

        logger.info(
            "Validation gaps detected",
            tenant_id=tenant_id,
            gaps_count=len(gaps),
            total_days=len(gap_dates)
        )

        return gaps

    except Exception as e:
        logger.error(
            "Failed to detect validation gaps",
            tenant_id=tenant_id,
            error=str(e)
        )
        raise DatabaseError(f"Failed to detect validation gaps: {str(e)}") from e
|
||||
|
||||
async def backfill_validation(
    self,
    tenant_id: uuid.UUID,
    start_date: date,
    end_date: date,
    triggered_by: str = "manual",
    sales_data_update_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """
    Backfill validation for a historical date range

    Args:
        tenant_id: Tenant identifier
        start_date: Start date for backfill (inclusive)
        end_date: End date for backfill (inclusive)
        triggered_by: How this backfill was triggered
        sales_data_update_id: Optional link to sales data update record

    Returns:
        Validation results plus a "backfill_date_range" entry

    Raises:
        DatabaseError: if the validation run fails
    """
    try:
        logger.info(
            "Starting validation backfill",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat(),
            triggered_by=triggered_by
        )

        # Expand the inclusive date range to full-day UTC datetimes
        start_datetime = datetime.combine(start_date, datetime.min.time()).replace(tzinfo=timezone.utc)
        end_datetime = datetime.combine(end_date, datetime.max.time()).replace(tzinfo=timezone.utc)

        # Run validation for the date range
        validation_result = await self.validation_service.validate_date_range(
            tenant_id=tenant_id,
            start_date=start_datetime,
            end_date=end_datetime,
            orchestration_run_id=None,
            triggered_by=triggered_by
        )

        # Update the linked sales data update record if provided.
        # Guard: a failed/partial validation may not return a run id;
        # previously uuid.UUID(None) would raise here and mask the result.
        if sales_data_update_id:
            run_id = validation_result.get("validation_run_id")
            await self._update_sales_data_record(
                sales_data_update_id=sales_data_update_id,
                validation_run_id=uuid.UUID(run_id) if run_id else None,
                status="completed" if validation_result.get("status") == "completed" else "failed"
            )

        logger.info(
            "Validation backfill completed",
            tenant_id=tenant_id,
            validation_run_id=validation_result.get("validation_run_id"),
            forecasts_evaluated=validation_result.get("forecasts_evaluated")
        )

        return {
            **validation_result,
            "backfill_date_range": {
                "start": start_date.isoformat(),
                "end": end_date.isoformat()
            }
        }

    except Exception as e:
        logger.error(
            "Validation backfill failed",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat(),
            error=str(e)
        )

        if sales_data_update_id:
            # Best-effort failure marking: never let this secondary update
            # mask the original error.
            try:
                await self._update_sales_data_record(
                    sales_data_update_id=sales_data_update_id,
                    validation_run_id=None,
                    status="failed",
                    error_message=str(e)
                )
            except Exception as update_error:
                logger.error(
                    "Failed to mark sales data update as failed",
                    sales_data_update_id=sales_data_update_id,
                    error=str(update_error)
                )

        raise DatabaseError(f"Validation backfill failed: {str(e)}") from e
|
||||
|
||||
async def auto_backfill_gaps(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
lookback_days: int = 90,
|
||||
max_gaps_to_process: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Automatically detect and backfill validation gaps
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
lookback_days: How far back to check
|
||||
max_gaps_to_process: Maximum number of gaps to process in one run
|
||||
|
||||
Returns:
|
||||
Summary of backfill operations
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting auto backfill",
|
||||
tenant_id=tenant_id,
|
||||
lookback_days=lookback_days
|
||||
)
|
||||
|
||||
# Detect gaps
|
||||
gaps = await self.detect_validation_gaps(tenant_id, lookback_days)
|
||||
|
||||
if not gaps:
|
||||
return {
|
||||
"gaps_found": 0,
|
||||
"gaps_processed": 0,
|
||||
"validations_completed": 0,
|
||||
"message": "No validation gaps found"
|
||||
}
|
||||
|
||||
# Limit number of gaps to process
|
||||
gaps_to_process = gaps[:max_gaps_to_process]
|
||||
|
||||
results = []
|
||||
for gap in gaps_to_process:
|
||||
try:
|
||||
result = await self.backfill_validation(
|
||||
tenant_id=tenant_id,
|
||||
start_date=gap["start_date"],
|
||||
end_date=gap["end_date"],
|
||||
triggered_by="auto_backfill"
|
||||
)
|
||||
results.append({
|
||||
"gap": gap,
|
||||
"result": result,
|
||||
"status": "success"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to backfill gap",
|
||||
gap=gap,
|
||||
error=str(e)
|
||||
)
|
||||
results.append({
|
||||
"gap": gap,
|
||||
"error": str(e),
|
||||
"status": "failed"
|
||||
})
|
||||
|
||||
successful = sum(1 for r in results if r["status"] == "success")
|
||||
|
||||
logger.info(
|
||||
"Auto backfill completed",
|
||||
tenant_id=tenant_id,
|
||||
gaps_found=len(gaps),
|
||||
gaps_processed=len(results),
|
||||
successful=successful
|
||||
)
|
||||
|
||||
return {
|
||||
"gaps_found": len(gaps),
|
||||
"gaps_processed": len(results),
|
||||
"validations_completed": successful,
|
||||
"validations_failed": len(results) - successful,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Auto backfill failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Auto backfill failed: {str(e)}")
|
||||
|
||||
    async def register_sales_data_update(
        self,
        tenant_id: uuid.UUID,
        start_date: date,
        end_date: date,
        records_affected: int,
        update_source: str = "import",
        import_job_id: Optional[str] = None,
        auto_trigger_validation: bool = True
    ) -> Dict[str, Any]:
        """
        Register a sales data update and optionally trigger validation

        Args:
            tenant_id: Tenant identifier
            start_date: Start date of updated data
            end_date: End date of updated data
            records_affected: Number of sales records affected
            update_source: Source of update (import, manual, pos_sync)
            import_job_id: Optional import job ID
            auto_trigger_validation: Whether to automatically trigger validation

        Returns:
            Update record and validation result if triggered
        """
        try:
            # Create sales data update record
            update_record = SalesDataUpdate(
                tenant_id=tenant_id,
                update_date_start=start_date,
                update_date_end=end_date,
                records_affected=records_affected,
                update_source=update_source,
                import_job_id=import_job_id,
                requires_validation=auto_trigger_validation,
                validation_status="pending" if auto_trigger_validation else "not_required"
            )

            self.db.add(update_record)
            # Flush (not commit) so the record gets its primary key now; it is
            # needed for the log line, the response payload, and the
            # sales_data_update_id passed to backfill_validation below.
            await self.db.flush()

            logger.info(
                "Registered sales data update",
                tenant_id=tenant_id,
                update_id=update_record.id,
                date_range=f"{start_date} to {end_date}",
                records_affected=records_affected
            )

            result = {
                "update_id": str(update_record.id),
                "update_record": update_record.to_dict(),
                "validation_triggered": False
            }

            # Trigger validation if requested
            if auto_trigger_validation:
                try:
                    validation_result = await self.backfill_validation(
                        tenant_id=tenant_id,
                        start_date=start_date,
                        end_date=end_date,
                        triggered_by="sales_data_update",
                        sales_data_update_id=update_record.id
                    )

                    result["validation_triggered"] = True
                    result["validation_result"] = validation_result

                    logger.info(
                        "Validation triggered for sales data update",
                        update_id=update_record.id,
                        validation_run_id=validation_result.get("validation_run_id")
                    )

                except Exception as e:
                    # A validation failure must not fail the registration itself;
                    # record the failure on the update row and carry on.
                    logger.error(
                        "Failed to trigger validation for sales data update",
                        update_id=update_record.id,
                        error=str(e)
                    )
                    update_record.validation_status = "failed"
                    # Truncate to fit the 500-character error column.
                    update_record.validation_error = str(e)[:500]

            # Single commit persists both the new record and any failure status
            # written in the except branch above.
            await self.db.commit()

            return result

        except Exception as e:
            logger.error(
                "Failed to register sales data update",
                tenant_id=tenant_id,
                error=str(e)
            )
            await self.db.rollback()
            raise DatabaseError(f"Failed to register sales data update: {str(e)}")
|
||||
|
||||
async def _update_sales_data_record(
|
||||
self,
|
||||
sales_data_update_id: uuid.UUID,
|
||||
validation_run_id: Optional[uuid.UUID],
|
||||
status: str,
|
||||
error_message: Optional[str] = None
|
||||
):
|
||||
"""Update sales data update record with validation results"""
|
||||
try:
|
||||
query = select(SalesDataUpdate).where(SalesDataUpdate.id == sales_data_update_id)
|
||||
result = await self.db.execute(query)
|
||||
update_record = result.scalar_one_or_none()
|
||||
|
||||
if update_record:
|
||||
update_record.validation_status = status
|
||||
update_record.validation_run_id = validation_run_id
|
||||
update_record.validated_at = datetime.now(timezone.utc)
|
||||
if error_message:
|
||||
update_record.validation_error = error_message[:500]
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update sales data record",
|
||||
sales_data_update_id=sales_data_update_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def get_pending_validations(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 50
|
||||
) -> List[SalesDataUpdate]:
|
||||
"""Get pending sales data updates that need validation"""
|
||||
try:
|
||||
query = (
|
||||
select(SalesDataUpdate)
|
||||
.where(
|
||||
and_(
|
||||
SalesDataUpdate.tenant_id == tenant_id,
|
||||
SalesDataUpdate.validation_status == "pending",
|
||||
SalesDataUpdate.requires_validation == True
|
||||
)
|
||||
)
|
||||
.order_by(SalesDataUpdate.created_at)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get pending validations",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise DatabaseError(f"Failed to get pending validations: {str(e)}")
|
||||
@@ -0,0 +1,435 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/performance_monitoring_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring Service
|
||||
|
||||
Monitors forecast accuracy over time and triggers actions when
|
||||
performance degrades below acceptable thresholds.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, desc
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.models.validation_run import ValidationRun
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.forecasts import Forecast
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceMonitoringService:
    """Service for monitoring forecast performance and triggering improvements"""

    # Configurable thresholds
    MAPE_WARNING_THRESHOLD = 20.0   # Warning if MAPE > 20%
    MAPE_CRITICAL_THRESHOLD = 30.0  # Critical if MAPE > 30%
    MAPE_TREND_THRESHOLD = 5.0      # Alert if MAPE increases by > 5% over period
    MIN_SAMPLES_FOR_ALERT = 5       # Minimum validations before alerting
    TREND_LOOKBACK_DAYS = 30        # Days to analyze for trends

    def __init__(self, db_session: AsyncSession):
        self.db = db_session

    @staticmethod
    def _mean(values) -> Optional[float]:
        """Arithmetic mean of the non-None items of *values*, or None if empty.

        Uses an explicit `is not None` filter so a legitimate metric value of
        0.0 (perfect accuracy) is included rather than dropped as falsy.
        """
        present = [v for v in values if v is not None]
        return sum(present) / len(present) if present else None

    async def get_accuracy_summary(
        self,
        tenant_id: uuid.UUID,
        days: int = 30
    ) -> Dict[str, Any]:
        """
        Get accuracy summary for recent period

        Args:
            tenant_id: Tenant identifier
            days: Number of days to analyze

        Returns:
            Summary with overall metrics and trends
        """
        try:
            start_date = datetime.now(timezone.utc) - timedelta(days=days)

            # Get recent validation runs
            query = (
                select(ValidationRun)
                .where(
                    and_(
                        ValidationRun.tenant_id == tenant_id,
                        ValidationRun.status == "completed",
                        ValidationRun.started_at >= start_date,
                        ValidationRun.forecasts_with_actuals > 0  # Only runs with actual data
                    )
                )
                .order_by(desc(ValidationRun.started_at))
            )

            result = await self.db.execute(query)
            runs = result.scalars().all()

            if not runs:
                return {
                    "status": "no_data",
                    "message": f"No validation runs found in last {days} days",
                    "period_days": days
                }

            # Calculate summary statistics
            total_forecasts = sum(r.total_forecasts_evaluated for r in runs)
            total_with_actuals = sum(r.forecasts_with_actuals for r in runs)

            avg_mape = self._mean(r.overall_mape for r in runs)
            avg_mae = self._mean(r.overall_mae for r in runs)
            avg_rmse = self._mean(r.overall_rmse for r in runs)

            # Determine health status.
            # BUGFIX: compare against None explicitly — the previous truthiness
            # checks (`if avg_mape and ...`) treated a perfect MAPE of 0.0 as
            # "no data", which also blanked the reported metrics below.
            health_status = "healthy"
            if avg_mape is not None and avg_mape > self.MAPE_CRITICAL_THRESHOLD:
                health_status = "critical"
            elif avg_mape is not None and avg_mape > self.MAPE_WARNING_THRESHOLD:
                health_status = "warning"

            return {
                "status": "ok",
                "period_days": days,
                "validation_runs": len(runs),
                "total_forecasts_evaluated": total_forecasts,
                "total_forecasts_with_actuals": total_with_actuals,
                "coverage_percentage": round(
                    (total_with_actuals / total_forecasts * 100) if total_forecasts > 0 else 0, 2
                ),
                "average_metrics": {
                    "mape": round(avg_mape, 2) if avg_mape is not None else None,
                    "mae": round(avg_mae, 2) if avg_mae is not None else None,
                    "rmse": round(avg_rmse, 2) if avg_rmse is not None else None,
                    "accuracy_percentage": round(100 - avg_mape, 2) if avg_mape is not None else None
                },
                "health_status": health_status,
                "thresholds": {
                    "warning": self.MAPE_WARNING_THRESHOLD,
                    "critical": self.MAPE_CRITICAL_THRESHOLD
                }
            }

        except Exception as e:
            logger.error(
                "Failed to get accuracy summary",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to get accuracy summary: {str(e)}")

    async def detect_performance_degradation(
        self,
        tenant_id: uuid.UUID,
        lookback_days: int = 30
    ) -> Dict[str, Any]:
        """
        Detect if forecast performance is degrading over time

        Compares average MAPE of the older half of the period's validation runs
        against the newer half.

        Args:
            tenant_id: Tenant identifier
            lookback_days: Days to analyze for trends

        Returns:
            Degradation analysis with recommendations
        """
        try:
            logger.info(
                "Detecting performance degradation",
                tenant_id=tenant_id,
                lookback_days=lookback_days
            )

            start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)

            # Get validation runs ordered by time
            query = (
                select(ValidationRun)
                .where(
                    and_(
                        ValidationRun.tenant_id == tenant_id,
                        ValidationRun.status == "completed",
                        ValidationRun.started_at >= start_date,
                        ValidationRun.forecasts_with_actuals > 0
                    )
                )
                .order_by(ValidationRun.started_at)
            )

            result = await self.db.execute(query)
            runs = list(result.scalars().all())

            if len(runs) < self.MIN_SAMPLES_FOR_ALERT:
                return {
                    "status": "insufficient_data",
                    "message": f"Need at least {self.MIN_SAMPLES_FOR_ALERT} validation runs",
                    "runs_found": len(runs)
                }

            # Split into first half and second half
            midpoint = len(runs) // 2
            first_half = runs[:midpoint]
            second_half = runs[midpoint:]

            # Calculate average MAPE for each half.
            # BUGFIX: the original filtered with truthiness (dropping 0.0 MAPEs)
            # and divided by the count of truthy values, which raised
            # ZeroDivisionError when a half contained only None MAPEs.
            first_half_mape = self._mean(r.overall_mape for r in first_half)
            second_half_mape = self._mean(r.overall_mape for r in second_half)

            if first_half_mape is None or second_half_mape is None:
                return {
                    "status": "insufficient_data",
                    "message": "Not enough MAPE data points to compute a trend",
                    "runs_found": len(runs)
                }

            mape_change = second_half_mape - first_half_mape
            mape_change_percentage = (mape_change / first_half_mape * 100) if first_half_mape > 0 else 0

            # Determine if degradation is significant
            is_degrading = mape_change > self.MAPE_TREND_THRESHOLD
            severity = "none"

            if is_degrading:
                if mape_change > self.MAPE_TREND_THRESHOLD * 2:
                    severity = "high"
                elif mape_change > self.MAPE_TREND_THRESHOLD:
                    severity = "medium"

            # Get products with worst performance
            poor_products = await self._identify_poor_performers(tenant_id, lookback_days)

            # Named `analysis` (not `result`) to avoid shadowing the DB result above.
            analysis = {
                "status": "analyzed",
                "period_days": lookback_days,
                "samples_analyzed": len(runs),
                "is_degrading": is_degrading,
                "severity": severity,
                "metrics": {
                    "first_period_mape": round(first_half_mape, 2),
                    "second_period_mape": round(second_half_mape, 2),
                    "mape_change": round(mape_change, 2),
                    "mape_change_percentage": round(mape_change_percentage, 2)
                },
                "poor_performers": poor_products,
                "recommendations": []
            }

            # Add recommendations
            if is_degrading:
                analysis["recommendations"].append({
                    "action": "retrain_models",
                    "priority": "high" if severity == "high" else "medium",
                    "reason": f"MAPE increased by {abs(mape_change):.1f}% over {lookback_days} days"
                })

            if poor_products:
                analysis["recommendations"].append({
                    "action": "retrain_poor_performers",
                    "priority": "high",
                    "reason": f"{len(poor_products)} products with MAPE > {self.MAPE_CRITICAL_THRESHOLD}%",
                    "products": poor_products[:10]  # Top 10 worst
                })

            logger.info(
                "Performance degradation analysis complete",
                tenant_id=tenant_id,
                is_degrading=is_degrading,
                severity=severity,
                poor_performers=len(poor_products)
            )

            return analysis

        except Exception as e:
            logger.error(
                "Failed to detect performance degradation",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to detect degradation: {str(e)}")

    async def _identify_poor_performers(
        self,
        tenant_id: uuid.UUID,
        lookback_days: int
    ) -> List[Dict[str, Any]]:
        """Identify products/locations with poor accuracy.

        Returns up to 20 products whose average MAPE over the lookback window
        exceeds MAPE_CRITICAL_THRESHOLD, worst first. Returns [] on query
        failure (best-effort helper — callers treat the list as advisory).
        """
        try:
            start_date = datetime.now(timezone.utc) - timedelta(days=lookback_days)

            # Get recent performance metrics grouped by product
            query = select(
                ModelPerformanceMetric.inventory_product_id,
                func.avg(ModelPerformanceMetric.mape).label('avg_mape'),
                func.avg(ModelPerformanceMetric.mae).label('avg_mae'),
                func.count(ModelPerformanceMetric.id).label('sample_count')
            ).where(
                and_(
                    ModelPerformanceMetric.tenant_id == tenant_id,
                    ModelPerformanceMetric.created_at >= start_date,
                    ModelPerformanceMetric.mape.isnot(None)
                )
            ).group_by(
                ModelPerformanceMetric.inventory_product_id
            ).having(
                func.avg(ModelPerformanceMetric.mape) > self.MAPE_CRITICAL_THRESHOLD
            ).order_by(
                desc(func.avg(ModelPerformanceMetric.mape))
            ).limit(20)

            result = await self.db.execute(query)

            return [
                {
                    "inventory_product_id": str(row.inventory_product_id),
                    "avg_mape": round(row.avg_mape, 2),
                    "avg_mae": round(row.avg_mae, 2),
                    "sample_count": row.sample_count,
                    "requires_retraining": True
                }
                for row in result.fetchall()
            ]

        except Exception as e:
            logger.error("Failed to identify poor performers", error=str(e))
            return []

    async def check_model_age(
        self,
        tenant_id: uuid.UUID,
        max_age_days: int = 30
    ) -> Dict[str, Any]:
        """
        Check if models are outdated and need retraining

        Args:
            tenant_id: Tenant identifier
            max_age_days: Maximum acceptable model age

        Returns:
            Analysis of model ages

        NOTE(review): ``outdated_models`` is currently always 0 — model
        training timestamps live in the training service and are not queried
        here yet. Do not rely on this field until that integration exists.
        """
        try:
            # Get distinct models used in recent forecasts (last 7 days).
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)

            query = select(
                Forecast.model_id,
                Forecast.model_version,
                Forecast.inventory_product_id,
                func.max(Forecast.created_at).label('last_used'),
                func.count(Forecast.id).label('forecast_count')
            ).where(
                and_(
                    Forecast.tenant_id == tenant_id,
                    Forecast.created_at >= cutoff_date
                )
            ).group_by(
                Forecast.model_id,
                Forecast.model_version,
                Forecast.inventory_product_id
            )

            result = await self.db.execute(query)

            # TODO: determine real model age via the training service; until
            # then outdated_count stays 0 by design.
            outdated_count = 0
            models_info = [
                {
                    "model_id": row.model_id,
                    "model_version": row.model_version,
                    "inventory_product_id": str(row.inventory_product_id),
                    "last_used": row.last_used.isoformat(),
                    "forecast_count": row.forecast_count
                }
                for row in result.fetchall()
            ]

            return {
                "status": "analyzed",
                "models_in_use": len(models_info),
                "outdated_models": outdated_count,
                "max_age_days": max_age_days,
                "models": models_info[:20]  # Top 20
            }

        except Exception as e:
            logger.error("Failed to check model age", error=str(e))
            raise DatabaseError(f"Failed to check model age: {str(e)}")

    async def generate_performance_report(
        self,
        tenant_id: uuid.UUID,
        days: int = 30
    ) -> Dict[str, Any]:
        """
        Generate comprehensive performance report

        Combines the accuracy summary, degradation analysis and model-age
        analysis into one payload with consolidated recommendations.

        Args:
            tenant_id: Tenant identifier
            days: Analysis period

        Returns:
            Complete performance report with recommendations
        """
        try:
            logger.info(
                "Generating performance report",
                tenant_id=tenant_id,
                days=days
            )

            # Get all analyses
            summary = await self.get_accuracy_summary(tenant_id, days)
            degradation = await self.detect_performance_degradation(tenant_id, days)
            model_age = await self.check_model_age(tenant_id)

            # Compile recommendations
            all_recommendations = []

            if summary.get("health_status") == "critical":
                # health_status == "critical" implies average_metrics exists.
                all_recommendations.append({
                    "priority": "critical",
                    "action": "immediate_review",
                    "reason": f"Overall MAPE is {summary['average_metrics']['mape']}%",
                    "details": "Forecast accuracy is critically low"
                })

            all_recommendations.extend(degradation.get("recommendations", []))

            report = {
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "tenant_id": str(tenant_id),
                "analysis_period_days": days,
                "summary": summary,
                "degradation_analysis": degradation,
                "model_age_analysis": model_age,
                "recommendations": all_recommendations,
                "requires_action": len(all_recommendations) > 0
            }

            logger.info(
                "Performance report generated",
                tenant_id=tenant_id,
                health_status=summary.get("health_status"),
                recommendations_count=len(all_recommendations)
            )

            return report

        except Exception as e:
            logger.error(
                "Failed to generate performance report",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to generate report: {str(e)}")
|
||||
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
384
services/forecasting/app/services/retraining_trigger_service.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/retraining_trigger_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining Trigger Service
|
||||
|
||||
Automatically triggers model retraining based on performance metrics,
|
||||
accuracy degradation, or data availability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.clients.training_client import TrainingServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RetrainingTriggerService:
    """Service for triggering automatic model retraining

    Evaluates recent forecast performance (via PerformanceMonitoringService)
    and asks the external training service to retrain products whose accuracy
    has degraded. All network calls go through TrainingServiceClient.
    """

    def __init__(self, db_session: AsyncSession):
        self.db = db_session
        self.performance_service = PerformanceMonitoringService(db_session)

        # Initialize training client.
        # NOTE(review): BaseServiceSettings() is constructed per instance —
        # presumably it reads environment-based config; confirm it is cheap.
        config = BaseServiceSettings()
        self.training_client = TrainingServiceClient(config, calling_service_name="forecasting")

    async def evaluate_and_trigger_retraining(
        self,
        tenant_id: uuid.UUID,
        auto_trigger: bool = True
    ) -> Dict[str, Any]:
        """
        Evaluate performance and trigger retraining if needed

        Args:
            tenant_id: Tenant identifier
            auto_trigger: Whether to automatically trigger retraining

        Returns:
            Evaluation results and retraining actions taken

        Raises:
            DatabaseError: If the evaluation itself fails. Per-product trigger
                failures are captured in the result, not raised.
        """
        try:
            logger.info(
                "Evaluating retraining needs",
                tenant_id=tenant_id,
                auto_trigger=auto_trigger
            )

            # Generate performance report (fixed 30-day window)
            report = await self.performance_service.generate_performance_report(
                tenant_id=tenant_id,
                days=30
            )

            if not report.get("requires_action"):
                logger.info(
                    "No retraining required",
                    tenant_id=tenant_id,
                    health_status=report["summary"].get("health_status")
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "health_status": report["summary"].get("health_status"),
                    "report": report
                }

            # Extract products that need retraining from the report's
            # "retrain_poor_performers" recommendations.
            products_to_retrain = []
            recommendations = report.get("recommendations", [])

            for rec in recommendations:
                if rec.get("action") == "retrain_poor_performers":
                    products_to_retrain.extend(rec.get("products", []))

            if not products_to_retrain and auto_trigger:
                # If degradation detected but no specific products, consider retraining all.
                # NOTE: currently only logged — no retraining is actually
                # triggered for this case.
                degradation = report.get("degradation_analysis", {})
                if degradation.get("is_degrading") and degradation.get("severity") in ["high", "medium"]:
                    logger.info(
                        "General degradation detected, considering full retraining",
                        tenant_id=tenant_id,
                        severity=degradation.get("severity")
                    )

            retraining_results = []

            if auto_trigger and products_to_retrain:
                # Trigger retraining for poor performers; failures are recorded
                # per product so one bad trigger doesn't abort the batch.
                for product in products_to_retrain:
                    try:
                        result = await self._trigger_product_retraining(
                            tenant_id=tenant_id,
                            inventory_product_id=uuid.UUID(product["inventory_product_id"]),
                            reason=f"MAPE {product['avg_mape']}% exceeds threshold",
                            priority="high"
                        )
                        retraining_results.append(result)
                    except Exception as e:
                        logger.error(
                            "Failed to trigger retraining for product",
                            product_id=product["inventory_product_id"],
                            error=str(e)
                        )
                        retraining_results.append({
                            "product_id": product["inventory_product_id"],
                            "status": "failed",
                            "error": str(e)
                        })

            logger.info(
                "Retraining evaluation complete",
                tenant_id=tenant_id,
                products_evaluated=len(products_to_retrain),
                retraining_triggered=len(retraining_results)
            )

            return {
                "status": "evaluated",
                "tenant_id": str(tenant_id),
                "requires_action": report.get("requires_action"),
                "products_needing_retraining": len(products_to_retrain),
                "retraining_triggered": len(retraining_results),
                "auto_trigger_enabled": auto_trigger,
                "retraining_results": retraining_results,
                "performance_report": report
            }

        except Exception as e:
            logger.error(
                "Failed to evaluate and trigger retraining",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to evaluate retraining: {str(e)}")

    async def _trigger_product_retraining(
        self,
        tenant_id: uuid.UUID,
        inventory_product_id: uuid.UUID,
        reason: str,
        priority: str = "normal"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for a specific product

        Args:
            tenant_id: Tenant identifier
            inventory_product_id: Product to retrain
            reason: Reason for retraining
            priority: Priority level (low, normal, high)

        Returns:
            Retraining trigger result — always a status dict; exceptions from
            the training client are converted to a "failed" entry, never raised.
        """
        try:
            logger.info(
                "Triggering product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                reason=reason,
                priority=priority
            )

            # Call training service to trigger retraining
            result = await self.training_client.trigger_retrain(
                tenant_id=str(tenant_id),
                inventory_product_id=str(inventory_product_id),
                reason=reason,
                priority=priority
            )

            if result:
                logger.info(
                    "Retraining triggered successfully",
                    tenant_id=tenant_id,
                    product_id=inventory_product_id,
                    training_job_id=result.get("training_job_id")
                )

                return {
                    "status": "triggered",
                    "product_id": str(inventory_product_id),
                    "training_job_id": result.get("training_job_id"),
                    "reason": reason,
                    "priority": priority,
                    "triggered_at": datetime.now(timezone.utc).isoformat()
                }
            else:
                # Client returned an empty/None payload — treated as a soft failure.
                logger.warning(
                    "Retraining trigger returned no result",
                    tenant_id=tenant_id,
                    product_id=inventory_product_id
                )
                return {
                    "status": "no_response",
                    "product_id": str(inventory_product_id),
                    "reason": reason
                }

        except Exception as e:
            logger.error(
                "Failed to trigger product retraining",
                tenant_id=tenant_id,
                product_id=inventory_product_id,
                error=str(e)
            )
            return {
                "status": "failed",
                "product_id": str(inventory_product_id),
                "error": str(e),
                "reason": reason
            }

    async def trigger_bulk_retraining(
        self,
        tenant_id: uuid.UUID,
        product_ids: List[uuid.UUID],
        reason: str = "Bulk retraining requested"
    ) -> Dict[str, Any]:
        """
        Trigger retraining for multiple products

        Products are triggered sequentially with "normal" priority; per-product
        failures are reflected in the counts, not raised.

        Args:
            tenant_id: Tenant identifier
            product_ids: List of products to retrain
            reason: Reason for bulk retraining

        Returns:
            Bulk retraining results
        """
        try:
            logger.info(
                "Triggering bulk retraining",
                tenant_id=tenant_id,
                product_count=len(product_ids)
            )

            results = []

            for product_id in product_ids:
                result = await self._trigger_product_retraining(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    reason=reason,
                    priority="normal"
                )
                results.append(result)

            successful = sum(1 for r in results if r["status"] == "triggered")

            logger.info(
                "Bulk retraining completed",
                tenant_id=tenant_id,
                total=len(product_ids),
                successful=successful,
                failed=len(product_ids) - successful
            )

            return {
                "status": "completed",
                "tenant_id": str(tenant_id),
                "total_products": len(product_ids),
                "successful": successful,
                "failed": len(product_ids) - successful,
                "results": results
            }

        except Exception as e:
            logger.error(
                "Bulk retraining failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Bulk retraining failed: {str(e)}")

    async def check_and_trigger_scheduled_retraining(
        self,
        tenant_id: uuid.UUID,
        max_model_age_days: int = 30
    ) -> Dict[str, Any]:
        """
        Check model ages and trigger retraining for outdated models

        Args:
            tenant_id: Tenant identifier
            max_model_age_days: Maximum acceptable model age

        Returns:
            Scheduled retraining results

        NOTE(review): check_model_age currently always reports 0 outdated
        models (age data lives in the training service), so this method is
        effectively a no-op until that integration lands.
        """
        try:
            logger.info(
                "Checking for scheduled retraining needs",
                tenant_id=tenant_id,
                max_model_age_days=max_model_age_days
            )

            # Get model age analysis
            model_age_analysis = await self.performance_service.check_model_age(
                tenant_id=tenant_id,
                max_age_days=max_model_age_days
            )

            outdated_count = model_age_analysis.get("outdated_models", 0)

            if outdated_count == 0:
                logger.info(
                    "No outdated models found",
                    tenant_id=tenant_id
                )
                return {
                    "status": "no_action_needed",
                    "tenant_id": str(tenant_id),
                    "outdated_models": 0
                }

            # TODO: Trigger retraining for outdated models
            # Would need to get list of outdated products from training service

            return {
                "status": "analyzed",
                "tenant_id": str(tenant_id),
                "outdated_models": outdated_count,
                "message": "Scheduled retraining analysis complete"
            }

        except Exception as e:
            logger.error(
                "Scheduled retraining check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Scheduled retraining check failed: {str(e)}")

    async def get_retraining_recommendations(
        self,
        tenant_id: uuid.UUID
    ) -> Dict[str, Any]:
        """
        Get retraining recommendations without triggering

        Runs the same evaluation as evaluate_and_trigger_retraining but with
        auto_trigger=False, then extracts just the recommendation payload.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Recommendations for manual review
        """
        try:
            # Evaluate without auto-triggering
            result = await self.evaluate_and_trigger_retraining(
                tenant_id=tenant_id,
                auto_trigger=False
            )

            # Extract just the recommendations
            # ("performance_report" is absent in the no_action_needed branch,
            # so every lookup below defaults safely).
            report = result.get("performance_report", {})
            recommendations = report.get("recommendations", [])

            return {
                "tenant_id": str(tenant_id),
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "requires_action": result.get("requires_action", False),
                "recommendations": recommendations,
                "summary": report.get("summary", {}),
                "degradation_detected": report.get("degradation_analysis", {}).get("is_degrading", False)
            }

        except Exception as e:
            logger.error(
                "Failed to get retraining recommendations",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to get recommendations: {str(e)}")
|
||||
97
services/forecasting/app/services/sales_client.py
Normal file
97
services/forecasting/app/services/sales_client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/sales_client.py
|
||||
# ================================================================
|
||||
"""
|
||||
Sales Client for Forecasting Service
|
||||
Wrapper around shared sales client with forecasting-specific methods
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from shared.clients.sales_client import SalesServiceClient
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SalesClient:
    """Client for fetching sales data from sales service."""

    def __init__(self):
        """Initialize the wrapper around the shared sales-service client."""
        # Settings are resolved from the environment by the shared config class.
        settings = BaseServiceSettings()
        self.client = SalesServiceClient(settings, calling_service_name="forecasting")

    async def get_sales_by_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        product_id: uuid.UUID = None
    ) -> List[Dict[str, Any]]:
        """
        Get sales data for a date range.

        Args:
            tenant_id: Tenant identifier
            start_date: Start of date range
            end_date: End of date range
            product_id: Optional product filter

        Returns:
            List of raw sales records. Returns an empty list both when no
            data exists and when the remote call fails (errors are logged,
            not raised), so callers can proceed best-effort.
        """
        try:
            # The shared client wants string IDs and ISO-8601 date strings.
            start_iso = start_date.isoformat() if start_date else None
            end_iso = end_date.isoformat() if end_date else None
            product_filter = str(product_id) if product_id else None

            # Fetch every page of raw (unaggregated) sales rows.
            sales_data = await self.client.get_all_sales_data(
                tenant_id=str(tenant_id),
                start_date=start_iso,
                end_date=end_iso,
                product_id=product_filter,
                aggregation="none",  # Get raw data without aggregation
                page_size=1000,
                max_pages=100
            )

            if not sales_data:
                logger.info(
                    "No sales data found for date range",
                    tenant_id=tenant_id,
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                return []

            logger.info(
                "Retrieved sales data",
                tenant_id=tenant_id,
                records_count=len(sales_data),
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )
            return sales_data

        except Exception as e:
            logger.error(
                "Failed to fetch sales data",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )
            # Deliberate best-effort behavior: downstream validation should
            # continue even when the sales service is unavailable.
            return []

    async def close(self):
        """Close the underlying client connection, if it supports closing."""
        if hasattr(self.client, 'close'):
            await self.client.close()
|
||||
586
services/forecasting/app/services/validation_service.py
Normal file
586
services/forecasting/app/services/validation_service.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/validation_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast Validation Service
|
||||
|
||||
Compares historical forecasts with actual sales data to:
|
||||
1. Calculate accuracy metrics (MAE, MAPE, RMSE, R², accuracy percentage)
|
||||
2. Store performance metrics in the database
|
||||
3. Track validation runs for audit purposes
|
||||
4. Enable continuous model improvement
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func, Date
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from app.models.forecasts import Forecast
|
||||
from app.models.predictions import ModelPerformanceMetric
|
||||
from app.models.validation_run import ValidationRun
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ValidationService:
    """Service for validating forecasts against actual sales data"""

    def __init__(self, db_session: AsyncSession):
        # Async SQLAlchemy session shared by all queries, flushes and commits
        # performed by this service; the caller owns its lifecycle.
        self.db = db_session
|
||||
|
||||
    async def validate_date_range(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime,
        orchestration_run_id: Optional[uuid.UUID] = None,
        triggered_by: str = "manual"
    ) -> Dict[str, Any]:
        """
        Validate forecasts against actual sales for a date range.

        Creates a ValidationRun audit row up front (status "running"), matches
        stored forecasts with actual sales, persists per-forecast performance
        metrics, then finalizes the run row with aggregate results. On any
        failure the run row is marked "failed" before DatabaseError is raised.

        Args:
            tenant_id: Tenant identifier
            start_date: Start of validation period
            end_date: End of validation period
            orchestration_run_id: Optional link to orchestration run
            triggered_by: How this validation was triggered (manual, orchestrator, scheduled)

        Returns:
            Dictionary with validation results and metrics

        Raises:
            DatabaseError: If any step of the validation fails.
        """
        validation_run = None

        try:
            # Create the audit record first so a trail exists even if the
            # validation itself fails part-way through.
            validation_run = ValidationRun(
                tenant_id=tenant_id,
                orchestration_run_id=orchestration_run_id,
                validation_start_date=start_date,
                validation_end_date=end_date,
                status="running",
                triggered_by=triggered_by,
                execution_mode="batch"
            )
            self.db.add(validation_run)
            # flush() assigns the primary key without committing the transaction.
            await self.db.flush()

            logger.info(
                "Starting forecast validation",
                validation_run_id=validation_run.id,
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat()
            )

            # Fetch forecasts joined (in memory) with their actual sales.
            forecasts_with_sales = await self._fetch_forecasts_with_sales(
                tenant_id, start_date, end_date
            )

            if not forecasts_with_sales:
                logger.warning(
                    "No forecasts with matching sales data found",
                    tenant_id=tenant_id,
                    start_date=start_date.isoformat(),
                    end_date=end_date.isoformat()
                )
                # NOTE(review): started_at is never set in this method —
                # presumably a model/server default; confirm it is
                # timezone-aware, since completed_at is UTC-aware.
                validation_run.status = "completed"
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

                return {
                    "validation_run_id": str(validation_run.id),
                    "status": "completed",
                    "message": "No forecasts with matching sales data found",
                    "forecasts_evaluated": 0,
                    "metrics_created": 0
                }

            # Calculate metrics and persist per-forecast performance records.
            metrics_results = await self._calculate_and_store_metrics(
                forecasts_with_sales, validation_run.id
            )

            # Copy aggregate results onto the audit row.
            validation_run.total_forecasts_evaluated = metrics_results["total_evaluated"]
            validation_run.forecasts_with_actuals = metrics_results["with_actuals"]
            validation_run.forecasts_without_actuals = metrics_results["without_actuals"]
            validation_run.overall_mae = metrics_results["overall_mae"]
            validation_run.overall_mape = metrics_results["overall_mape"]
            validation_run.overall_rmse = metrics_results["overall_rmse"]
            validation_run.overall_r2_score = metrics_results["overall_r2_score"]
            validation_run.overall_accuracy_percentage = metrics_results["overall_accuracy_percentage"]
            validation_run.total_predicted_demand = metrics_results["total_predicted"]
            validation_run.total_actual_demand = metrics_results["total_actual"]
            validation_run.metrics_by_product = metrics_results["metrics_by_product"]
            validation_run.metrics_by_location = metrics_results["metrics_by_location"]
            validation_run.metrics_records_created = metrics_results["metrics_created"]
            validation_run.status = "completed"
            validation_run.completed_at = datetime.now(timezone.utc)
            validation_run.duration_seconds = (
                validation_run.completed_at - validation_run.started_at
            ).total_seconds()

            await self.db.commit()

            logger.info(
                "Forecast validation completed successfully",
                validation_run_id=validation_run.id,
                forecasts_evaluated=validation_run.total_forecasts_evaluated,
                metrics_created=validation_run.metrics_records_created,
                overall_mape=validation_run.overall_mape,
                duration_seconds=validation_run.duration_seconds
            )

            # Flag products whose per-product MAPE exceeds 30% so callers can
            # surface them for retraining or manual review.
            poor_accuracy_products = []
            if validation_run.metrics_by_product:
                for product_id, product_metrics in validation_run.metrics_by_product.items():
                    if product_metrics.get("mape", 0) > 30:
                        poor_accuracy_products.append({
                            "product_id": product_id,
                            "mape": product_metrics.get("mape"),
                            "mae": product_metrics.get("mae"),
                            "accuracy_percentage": product_metrics.get("accuracy_percentage")
                        })

            return {
                "validation_run_id": str(validation_run.id),
                "status": "completed",
                "forecasts_evaluated": validation_run.total_forecasts_evaluated,
                "forecasts_with_actuals": validation_run.forecasts_with_actuals,
                "forecasts_without_actuals": validation_run.forecasts_without_actuals,
                "metrics_created": validation_run.metrics_records_created,
                "overall_metrics": {
                    "mae": validation_run.overall_mae,
                    "mape": validation_run.overall_mape,
                    "rmse": validation_run.overall_rmse,
                    "r2_score": validation_run.overall_r2_score,
                    "accuracy_percentage": validation_run.overall_accuracy_percentage
                },
                "total_predicted_demand": validation_run.total_predicted_demand,
                "total_actual_demand": validation_run.total_actual_demand,
                "duration_seconds": validation_run.duration_seconds,
                "poor_accuracy_products": poor_accuracy_products,
                "metrics_by_product": validation_run.metrics_by_product,
                "metrics_by_location": validation_run.metrics_by_location
            }

        except Exception as e:
            logger.error(
                "Forecast validation failed",
                tenant_id=tenant_id,
                error=str(e),
                error_type=type(e).__name__
            )

            # Best-effort: record the failure on the audit row before raising.
            if validation_run:
                validation_run.status = "failed"
                validation_run.error_message = str(e)
                validation_run.error_details = {"error_type": type(e).__name__}
                validation_run.completed_at = datetime.now(timezone.utc)
                validation_run.duration_seconds = (
                    validation_run.completed_at - validation_run.started_at
                ).total_seconds()
                await self.db.commit()

            raise DatabaseError(f"Forecast validation failed: {str(e)}")
|
||||
|
||||
async def validate_yesterday(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
orchestration_run_id: Optional[uuid.UUID] = None,
|
||||
triggered_by: str = "orchestrator"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convenience method to validate yesterday's forecasts
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
orchestration_run_id: Optional link to orchestration run
|
||||
triggered_by: How this validation was triggered
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
yesterday = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
start_date = yesterday.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = yesterday.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
|
||||
return await self.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by=triggered_by
|
||||
)
|
||||
|
||||
    async def _fetch_forecasts_with_sales(
        self,
        tenant_id: uuid.UUID,
        start_date: datetime,
        end_date: datetime
    ) -> List[Dict[str, Any]]:
        """
        Fetch forecasts with their corresponding actual sales data.

        Loads all Forecast rows in [start_date, end_date], fetches raw sales
        from the sales service for the same window, then joins them in memory
        on (product_id, calendar date). Forecasts without a matching sale are
        still returned, with actual_sales=None and has_actual_data=False.

        Raises:
            DatabaseError: If either the DB query or the join fails.
        """
        try:
            # Import here to avoid circular dependency
            from app.services.sales_client import SalesClient

            # All forecasts in the range, compared on calendar date (the cast
            # strips any time component from forecast_date).
            query = select(Forecast).where(
                and_(
                    Forecast.tenant_id == tenant_id,
                    func.cast(Forecast.forecast_date, Date) >= start_date.date(),
                    func.cast(Forecast.forecast_date, Date) <= end_date.date()
                )
            ).order_by(Forecast.forecast_date, Forecast.inventory_product_id)

            result = await self.db.execute(query)
            forecasts = result.scalars().all()

            if not forecasts:
                return []

            # Fetch actual sales data from sales service.
            # NOTE(review): SalesClient returns [] on remote failure, so a
            # sales-service outage looks identical to "no sales" here.
            sales_client = SalesClient()
            sales_data = await sales_client.get_sales_by_date_range(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date
            )

            # Build lookup: (product_id, date) -> aggregated sales record.
            sales_lookup = {}
            for sale in sales_data:
                sale_date = sale['date']
                if isinstance(sale_date, str):
                    # Normalize trailing 'Z' so fromisoformat accepts it.
                    sale_date = datetime.fromisoformat(sale_date.replace('Z', '+00:00'))

                key = (str(sale['inventory_product_id']), sale_date.date())

                # Sum quantities if multiple sales records exist for the same
                # product/date; mutates the first record seen for that key.
                if key in sales_lookup:
                    sales_lookup[key]['quantity_sold'] += sale['quantity_sold']
                else:
                    sales_lookup[key] = sale

            # Match each forecast with its (optional) aggregated sales record.
            forecasts_with_sales = []
            for forecast in forecasts:
                # forecast_date may be a datetime or already a date.
                forecast_date = forecast.forecast_date.date() if hasattr(forecast.forecast_date, 'date') else forecast.forecast_date
                key = (str(forecast.inventory_product_id), forecast_date)

                sales_record = sales_lookup.get(key)

                forecasts_with_sales.append({
                    "forecast_id": forecast.id,
                    "tenant_id": forecast.tenant_id,
                    "inventory_product_id": forecast.inventory_product_id,
                    "product_name": forecast.product_name,
                    "location": forecast.location,
                    "forecast_date": forecast.forecast_date,
                    "predicted_demand": forecast.predicted_demand,
                    "confidence_lower": forecast.confidence_lower,
                    "confidence_upper": forecast.confidence_upper,
                    "model_id": forecast.model_id,
                    "model_version": forecast.model_version,
                    "algorithm": forecast.algorithm,
                    "created_at": forecast.created_at,
                    "actual_sales": sales_record['quantity_sold'] if sales_record else None,
                    "has_actual_data": sales_record is not None
                })

            return forecasts_with_sales

        except Exception as e:
            logger.error(
                "Failed to fetch forecasts with sales data",
                tenant_id=tenant_id,
                error=str(e)
            )
            raise DatabaseError(f"Failed to fetch forecast and sales data: {str(e)}")
|
||||
|
||||
async def _calculate_and_store_metrics(
|
||||
self,
|
||||
forecasts_with_sales: List[Dict[str, Any]],
|
||||
validation_run_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate accuracy metrics and store them in the database
|
||||
|
||||
Returns summary of metrics calculated
|
||||
"""
|
||||
# Separate forecasts with and without actual data
|
||||
forecasts_with_actuals = [f for f in forecasts_with_sales if f["has_actual_data"]]
|
||||
forecasts_without_actuals = [f for f in forecasts_with_sales if not f["has_actual_data"]]
|
||||
|
||||
if not forecasts_with_actuals:
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": 0,
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": 0,
|
||||
"overall_mae": None,
|
||||
"overall_mape": None,
|
||||
"overall_rmse": None,
|
||||
"overall_r2_score": None,
|
||||
"overall_accuracy_percentage": None,
|
||||
"total_predicted": 0.0,
|
||||
"total_actual": 0.0,
|
||||
"metrics_by_product": {},
|
||||
"metrics_by_location": {}
|
||||
}
|
||||
|
||||
# Calculate individual metrics and prepare for bulk insert
|
||||
performance_metrics = []
|
||||
errors = []
|
||||
|
||||
for forecast in forecasts_with_actuals:
|
||||
predicted = forecast["predicted_demand"]
|
||||
actual = forecast["actual_sales"]
|
||||
|
||||
# Calculate error
|
||||
error = abs(predicted - actual)
|
||||
errors.append(error)
|
||||
|
||||
# Calculate percentage error (avoid division by zero)
|
||||
percentage_error = (error / actual * 100) if actual > 0 else 0.0
|
||||
|
||||
# Calculate individual metrics
|
||||
mae = error
|
||||
mape = percentage_error
|
||||
rmse = error # Will be squared and averaged later
|
||||
|
||||
# Calculate accuracy percentage (100% - MAPE, capped at 0)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
# Create performance metric record
|
||||
metric = ModelPerformanceMetric(
|
||||
model_id=uuid.UUID(forecast["model_id"]) if isinstance(forecast["model_id"], str) else forecast["model_id"],
|
||||
tenant_id=forecast["tenant_id"],
|
||||
inventory_product_id=forecast["inventory_product_id"],
|
||||
mae=mae,
|
||||
mape=mape,
|
||||
rmse=rmse,
|
||||
accuracy_score=accuracy_percentage / 100.0, # Store as 0-1 scale
|
||||
evaluation_date=forecast["forecast_date"],
|
||||
evaluation_period_start=forecast["forecast_date"],
|
||||
evaluation_period_end=forecast["forecast_date"],
|
||||
sample_size=1
|
||||
)
|
||||
performance_metrics.append(metric)
|
||||
|
||||
# Bulk insert all performance metrics
|
||||
if performance_metrics:
|
||||
self.db.add_all(performance_metrics)
|
||||
await self.db.flush()
|
||||
|
||||
# Calculate overall metrics
|
||||
overall_metrics = self._calculate_overall_metrics(forecasts_with_actuals)
|
||||
|
||||
# Calculate metrics by product
|
||||
metrics_by_product = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "inventory_product_id"
|
||||
)
|
||||
|
||||
# Calculate metrics by location
|
||||
metrics_by_location = self._calculate_metrics_by_dimension(
|
||||
forecasts_with_actuals, "location"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_evaluated": len(forecasts_with_sales),
|
||||
"with_actuals": len(forecasts_with_actuals),
|
||||
"without_actuals": len(forecasts_without_actuals),
|
||||
"metrics_created": len(performance_metrics),
|
||||
"overall_mae": overall_metrics["mae"],
|
||||
"overall_mape": overall_metrics["mape"],
|
||||
"overall_rmse": overall_metrics["rmse"],
|
||||
"overall_r2_score": overall_metrics["r2_score"],
|
||||
"overall_accuracy_percentage": overall_metrics["accuracy_percentage"],
|
||||
"total_predicted": overall_metrics["total_predicted"],
|
||||
"total_actual": overall_metrics["total_actual"],
|
||||
"metrics_by_product": metrics_by_product,
|
||||
"metrics_by_location": metrics_by_location
|
||||
}
|
||||
|
||||
def _calculate_overall_metrics(self, forecasts: List[Dict[str, Any]]) -> Dict[str, float]:
|
||||
"""Calculate aggregated metrics across all forecasts"""
|
||||
if not forecasts:
|
||||
return {
|
||||
"mae": None, "mape": None, "rmse": None,
|
||||
"r2_score": None, "accuracy_percentage": None,
|
||||
"total_predicted": 0.0, "total_actual": 0.0
|
||||
}
|
||||
|
||||
predicted_values = [f["predicted_demand"] for f in forecasts]
|
||||
actual_values = [f["actual_sales"] for f in forecasts]
|
||||
|
||||
# MAE: Mean Absolute Error
|
||||
mae = sum(abs(p - a) for p, a in zip(predicted_values, actual_values)) / len(forecasts)
|
||||
|
||||
# MAPE: Mean Absolute Percentage Error (handle division by zero)
|
||||
mape_values = [
|
||||
abs(p - a) / a * 100 if a > 0 else 0.0
|
||||
for p, a in zip(predicted_values, actual_values)
|
||||
]
|
||||
mape = sum(mape_values) / len(mape_values)
|
||||
|
||||
# RMSE: Root Mean Square Error
|
||||
squared_errors = [(p - a) ** 2 for p, a in zip(predicted_values, actual_values)]
|
||||
rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
|
||||
|
||||
# R² Score (coefficient of determination)
|
||||
mean_actual = sum(actual_values) / len(actual_values)
|
||||
ss_total = sum((a - mean_actual) ** 2 for a in actual_values)
|
||||
ss_residual = sum((a - p) ** 2 for a, p in zip(actual_values, predicted_values))
|
||||
r2_score = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0
|
||||
|
||||
# Accuracy percentage (100% - MAPE)
|
||||
accuracy_percentage = max(0.0, 100.0 - mape)
|
||||
|
||||
return {
|
||||
"mae": round(mae, 2),
|
||||
"mape": round(mape, 2),
|
||||
"rmse": round(rmse, 2),
|
||||
"r2_score": round(r2_score, 4),
|
||||
"accuracy_percentage": round(accuracy_percentage, 2),
|
||||
"total_predicted": round(sum(predicted_values), 2),
|
||||
"total_actual": round(sum(actual_values), 2)
|
||||
}
|
||||
|
||||
def _calculate_metrics_by_dimension(
|
||||
self,
|
||||
forecasts: List[Dict[str, Any]],
|
||||
dimension_key: str
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""Calculate metrics grouped by a dimension (product_id or location)"""
|
||||
dimension_groups = {}
|
||||
|
||||
# Group forecasts by dimension
|
||||
for forecast in forecasts:
|
||||
key = str(forecast[dimension_key])
|
||||
if key not in dimension_groups:
|
||||
dimension_groups[key] = []
|
||||
dimension_groups[key].append(forecast)
|
||||
|
||||
# Calculate metrics for each group
|
||||
metrics_by_dimension = {}
|
||||
for key, group_forecasts in dimension_groups.items():
|
||||
metrics = self._calculate_overall_metrics(group_forecasts)
|
||||
metrics_by_dimension[key] = {
|
||||
"count": len(group_forecasts),
|
||||
"mae": metrics["mae"],
|
||||
"mape": metrics["mape"],
|
||||
"rmse": metrics["rmse"],
|
||||
"accuracy_percentage": metrics["accuracy_percentage"],
|
||||
"total_predicted": metrics["total_predicted"],
|
||||
"total_actual": metrics["total_actual"]
|
||||
}
|
||||
|
||||
return metrics_by_dimension
|
||||
|
||||
async def get_validation_run(self, validation_run_id: uuid.UUID) -> Optional[ValidationRun]:
|
||||
"""Get a validation run by ID"""
|
||||
try:
|
||||
query = select(ValidationRun).where(ValidationRun.id == validation_run_id)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get validation run", validation_run_id=validation_run_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get validation run: {str(e)}")
|
||||
|
||||
async def get_validation_runs_by_tenant(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
skip: int = 0
|
||||
) -> List[ValidationRun]:
|
||||
"""Get validation runs for a tenant"""
|
||||
try:
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(ValidationRun.tenant_id == tenant_id)
|
||||
.order_by(ValidationRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(skip)
|
||||
)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get validation runs", tenant_id=tenant_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get validation runs: {str(e)}")
|
||||
|
||||
async def get_accuracy_trends(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
days: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Get accuracy trends over time"""
|
||||
try:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
query = (
|
||||
select(ValidationRun)
|
||||
.where(
|
||||
and_(
|
||||
ValidationRun.tenant_id == tenant_id,
|
||||
ValidationRun.status == "completed",
|
||||
ValidationRun.created_at >= start_date
|
||||
)
|
||||
)
|
||||
.order_by(ValidationRun.created_at)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
runs = result.scalars().all()
|
||||
|
||||
if not runs:
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_runs": 0,
|
||||
"trends": []
|
||||
}
|
||||
|
||||
trends = [
|
||||
{
|
||||
"date": run.validation_start_date.isoformat(),
|
||||
"mae": run.overall_mae,
|
||||
"mape": run.overall_mape,
|
||||
"rmse": run.overall_rmse,
|
||||
"accuracy_percentage": run.overall_accuracy_percentage,
|
||||
"forecasts_evaluated": run.total_forecasts_evaluated,
|
||||
"forecasts_with_actuals": run.forecasts_with_actuals
|
||||
}
|
||||
for run in runs
|
||||
]
|
||||
|
||||
# Calculate averages
|
||||
valid_runs = [r for r in runs if r.overall_mape is not None]
|
||||
avg_mape = sum(r.overall_mape for r in valid_runs) / len(valid_runs) if valid_runs else None
|
||||
avg_accuracy = sum(r.overall_accuracy_percentage for r in valid_runs) / len(valid_runs) if valid_runs else None
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_runs": len(runs),
|
||||
"average_mape": round(avg_mape, 2) if avg_mape else None,
|
||||
"average_accuracy": round(avg_accuracy, 2) if avg_accuracy else None,
|
||||
"trends": trends
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get accuracy trends", tenant_id=tenant_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get accuracy trends: {str(e)}")
|
||||
Reference in New Issue
Block a user