347 lines
11 KiB
Python
347 lines
11 KiB
Python
# ================================================================
|
|
# services/forecasting/app/api/validation.py
|
|
# ================================================================
|
|
"""
|
|
Validation API - Forecast validation endpoints
|
|
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
|
from typing import Dict, Any, List, Optional
|
|
from uuid import UUID
|
|
from datetime import datetime, timedelta, timezone
|
|
import structlog
|
|
|
|
from pydantic import BaseModel, Field
|
|
from app.services.validation_service import ValidationService
|
|
from shared.auth.decorators import get_current_user_dep
|
|
from shared.auth.access_control import require_user_role
|
|
from shared.routing import RouteBuilder
|
|
from app.core.database import get_db
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
# Route helper for the forecasting service; every endpoint path below is built
# through build_base_route(). NOTE(review): presumably it inserts the
# {tenant_id} segment that each endpoint declares as a Path parameter —
# confirm in shared.routing.
route_builder = RouteBuilder("forecasting")

# All validation endpoints are grouped under the "validation" OpenAPI tag.
router = APIRouter(tags=["validation"])

# Module-level structured logger shared by every endpoint in this file.
logger = structlog.get_logger()
|
|
|
|
|
|
# ================================================================
|
|
# Request/Response Schemas
|
|
# ================================================================
|
|
|
|
class ValidationRequest(BaseModel):
    """Request model for validation

    Body of the validate-date-range endpoint. ``start_date`` and ``end_date``
    are required; ``orchestration_run_id`` and ``triggered_by`` are optional
    and are passed through to the validation service unchanged.
    """

    # Required bounds of the validation period.
    start_date: datetime = Field(..., description="Start date for validation period")
    end_date: datetime = Field(..., description="End date for validation period")

    # Optional link to the orchestration run that requested this validation.
    orchestration_run_id: Optional[UUID] = Field(default=None, description="Optional orchestration run ID")

    # Free-form label recorded with the run; "manual" when the caller omits it.
    triggered_by: str = Field("manual", description="Trigger source")
|
|
|
|
|
|
class ValidationResponse(BaseModel):
    """Response model for validation results

    Summary returned by the validate-date-range and validate-yesterday
    endpoints; it is built directly from the dict the validation service
    returns.
    """

    # Identity and outcome of the run.
    validation_run_id: str
    status: str

    # Counts of forecasts considered, split by whether actual sales existed.
    forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int

    # Number of performance-metric records persisted by the run.
    metrics_created: int

    # Optional aggregates; None when the service does not supply them.
    overall_metrics: Optional[Dict[str, float]] = None
    total_predicted_demand: Optional[float] = None
    total_actual_demand: Optional[float] = None
    duration_seconds: Optional[float] = None

    # Optional human-readable note from the service.
    message: Optional[str] = None
|
|
|
|
|
|
class ValidationRunResponse(BaseModel):
    """Response model for validation run details

    Serialized form of a stored validation run; the endpoints build it from
    the run object's ``to_dict()`` output. Dates and timestamps are carried
    as strings.
    """

    # Identifiers.
    id: str
    tenant_id: str
    orchestration_run_id: Optional[str]

    # Period validated and execution timing.
    validation_start_date: str
    validation_end_date: str
    started_at: str
    completed_at: Optional[str]
    duration_seconds: Optional[float]
    status: str

    # Forecast counts.
    total_forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int

    # Overall accuracy metrics (may be absent for a run without actuals).
    overall_mae: Optional[float]
    overall_mape: Optional[float]
    overall_rmse: Optional[float]
    overall_r2_score: Optional[float]
    overall_accuracy_percentage: Optional[float]

    # Demand totals across the validated period.
    total_predicted_demand: float
    total_actual_demand: float

    # Per-product / per-location breakdowns.
    metrics_by_product: Optional[Dict[str, Any]]
    metrics_by_location: Optional[Dict[str, Any]]

    # Bookkeeping about the run itself.
    metrics_records_created: int
    error_message: Optional[str]
    triggered_by: str
    execution_mode: str
|
|
|
|
|
|
class AccuracyTrendResponse(BaseModel):
    """Response model for accuracy trends

    Built directly from the dict returned by
    ``ValidationService.get_accuracy_trends``.
    """

    # Size of the look-back window requested, in days.
    period_days: int

    # Number of validation runs found in that window.
    total_runs: int

    # Averages across those runs; None when they cannot be computed.
    average_mape: Optional[float]
    average_accuracy: Optional[float]

    # Per-run (or per-period) trend entries as produced by the service.
    trends: List[Dict[str, Any]]
|
|
|
|
|
|
# ================================================================
|
|
# Endpoints
|
|
# ================================================================
|
|
|
|
@router.post(
    route_builder.build_base_route("validation/validate-date-range"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_date_range(
    validation_request: ValidationRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate forecasts against actual sales for a date range

    This endpoint:
    - Fetches forecasts for the specified date range
    - Retrieves corresponding actual sales data
    - Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
    - Stores performance metrics in the database
    - Returns validation summary

    Raises:
        HTTPException 400: start_date is after end_date, or one date carries
            a timezone while the other does not (they cannot be compared).
        HTTPException 500: the validation service failed.
    """
    start_date = validation_request.start_date
    end_date = validation_request.end_date

    # Validate the range before the try block below, so a client error is
    # reported as 400 instead of being rewrapped as a 500.
    try:
        range_is_inverted = start_date > end_date
    except TypeError:
        # Python refuses to order a naive datetime against an aware one.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="start_date and end_date must both include a timezone or both omit it"
        ) from None
    if range_is_inverted:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="start_date must not be after end_date"
        )

    try:
        logger.info(
            "Starting date range validation",
            tenant_id=tenant_id,
            start_date=start_date.isoformat(),
            end_date=end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        validation_service = ValidationService(db)

        result = await validation_service.validate_date_range(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            orchestration_run_id=validation_request.orchestration_run_id,
            triggered_by=validation_request.triggered_by
        )

        logger.info(
            "Date range validation completed",
            tenant_id=tenant_id,
            validation_run_id=result.get("validation_run_id"),
            forecasts_evaluated=result.get("forecasts_evaluated")
        )

        return ValidationResponse(**result)

    except Exception as e:
        # exc_info keeps the traceback in the log; the summary fields alone
        # are not enough to diagnose a 500.
        logger.error(
            "Failed to validate date range",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__,
            exc_info=True
        )
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider a generic message now that the log carries the details.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate forecasts: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.post(
    route_builder.build_base_route("validation/validate-yesterday"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_yesterday(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db),
    triggered_by: str = Query("manual", description="Trigger source")
):
    """
    Validate yesterday's forecasts against actual sales

    Convenience endpoint for validating the most recent day's forecasts.
    This is typically called by the orchestrator as part of the daily workflow.

    The optional ``triggered_by`` query parameter lets such callers label the
    run accurately (it defaults to "manual", the previously hardcoded value),
    matching the ``triggered_by`` field of the date-range endpoint.

    Raises:
        HTTPException 500: the validation service failed.
    """
    try:
        logger.info(
            "Starting yesterday validation",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
            triggered_by=triggered_by
        )

        validation_service = ValidationService(db)

        result = await validation_service.validate_yesterday(
            tenant_id=tenant_id,
            orchestration_run_id=orchestration_run_id,
            triggered_by=triggered_by
        )

        logger.info(
            "Yesterday validation completed",
            tenant_id=tenant_id,
            validation_run_id=result.get("validation_run_id"),
            forecasts_evaluated=result.get("forecasts_evaluated")
        )

        return ValidationResponse(**result)

    except Exception as e:
        # exc_info keeps the traceback so the failure can be diagnosed.
        logger.error(
            "Failed to validate yesterday",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__,
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate yesterday's forecasts: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    route_builder.build_base_route("validation/runs/{validation_run_id}"),
    response_model=ValidationRunResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_run(
    validation_run_id: UUID = Path(..., description="Validation run ID"),
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get details of a specific validation run

    Returns complete information about a validation execution including:
    - Summary statistics
    - Overall accuracy metrics
    - Breakdown by product and location
    - Execution metadata

    Raises:
        HTTPException 404: no run with this ID exists.
        HTTPException 403: the run belongs to a different tenant.
        HTTPException 500: any other failure while loading the run.
    """
    try:
        validation_service = ValidationService(db)

        validation_run = await validation_service.get_validation_run(validation_run_id)

        if not validation_run:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Validation run {validation_run_id} not found"
            )

        # Compare normalized strings: the ORM attribute may be a uuid.UUID or a
        # str depending on the column type, and uuid.UUID(...) never equals a
        # str in Python, which would deny every legitimate request.
        run_tenant_id = validation_run.tenant_id
        if run_tenant_id is None or str(run_tenant_id).lower() != str(tenant_id):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to this validation run"
            )

        return ValidationRunResponse(**validation_run.to_dict())

    except HTTPException:
        # Deliberate 4xx responses from above pass through unchanged.
        raise
    except Exception as e:
        logger.error(
            "Failed to get validation run",
            validation_run_id=validation_run_id,
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation run: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    route_builder.build_base_route("validation/runs"),
    response_model=List[ValidationRunResponse],
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_runs(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get validation runs for a tenant

    Returns a list of validation executions with pagination support.
    ``limit`` is bounded to 1..100 and ``skip`` must be non-negative
    (enforced by the Query constraints).

    Raises:
        HTTPException 500: the runs could not be loaded or serialized.
    """
    try:
        validation_service = ValidationService(db)

        runs = await validation_service.get_validation_runs_by_tenant(
            tenant_id=tenant_id,
            limit=limit,
            skip=skip
        )

        return [ValidationRunResponse(**run.to_dict()) for run in runs]

    except Exception as e:
        # exc_info keeps the traceback so the failure can be diagnosed.
        logger.error(
            "Failed to get validation runs",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation runs: {str(e)}"
        ) from e
|
|
|
|
|
|
@router.get(
    route_builder.build_base_route("validation/trends"),
    response_model=AccuracyTrendResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_trends(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get accuracy trends over time

    Returns validation accuracy metrics over the specified time period.
    Useful for monitoring model performance degradation and improvement.
    ``days`` is bounded to 1..365 by the Query constraints.

    Raises:
        HTTPException 500: the trends could not be computed.
    """
    try:
        validation_service = ValidationService(db)

        trends = await validation_service.get_accuracy_trends(
            tenant_id=tenant_id,
            days=days
        )

        return AccuracyTrendResponse(**trends)

    except Exception as e:
        # exc_info keeps the traceback so the failure can be diagnosed.
        logger.error(
            "Failed to get accuracy trends",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get accuracy trends: {str(e)}"
        ) from e
|