Improve backend
This commit is contained in:
346
services/forecasting/app/api/validation.py
Normal file
346
services/forecasting/app/api/validation.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation API - Forecast validation endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
route_builder = RouteBuilder('forecasting')
|
||||
router = APIRouter(tags=["validation"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class ValidationRequest(BaseModel):
|
||||
"""Request model for validation"""
|
||||
start_date: datetime = Field(..., description="Start date for validation period")
|
||||
end_date: datetime = Field(..., description="End date for validation period")
|
||||
orchestration_run_id: Optional[UUID] = Field(None, description="Optional orchestration run ID")
|
||||
triggered_by: str = Field(default="manual", description="Trigger source")
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
|
||||
"""Response model for validation results"""
|
||||
validation_run_id: str
|
||||
status: str
|
||||
forecasts_evaluated: int
|
||||
forecasts_with_actuals: int
|
||||
forecasts_without_actuals: int
|
||||
metrics_created: int
|
||||
overall_metrics: Optional[Dict[str, float]] = None
|
||||
total_predicted_demand: Optional[float] = None
|
||||
total_actual_demand: Optional[float] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class ValidationRunResponse(BaseModel):
|
||||
"""Response model for validation run details"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
orchestration_run_id: Optional[str]
|
||||
validation_start_date: str
|
||||
validation_end_date: str
|
||||
started_at: str
|
||||
completed_at: Optional[str]
|
||||
duration_seconds: Optional[float]
|
||||
status: str
|
||||
total_forecasts_evaluated: int
|
||||
forecasts_with_actuals: int
|
||||
forecasts_without_actuals: int
|
||||
overall_mae: Optional[float]
|
||||
overall_mape: Optional[float]
|
||||
overall_rmse: Optional[float]
|
||||
overall_r2_score: Optional[float]
|
||||
overall_accuracy_percentage: Optional[float]
|
||||
total_predicted_demand: float
|
||||
total_actual_demand: float
|
||||
metrics_by_product: Optional[Dict[str, Any]]
|
||||
metrics_by_location: Optional[Dict[str, Any]]
|
||||
metrics_records_created: int
|
||||
error_message: Optional[str]
|
||||
triggered_by: str
|
||||
execution_mode: str
|
||||
|
||||
|
||||
class AccuracyTrendResponse(BaseModel):
|
||||
"""Response model for accuracy trends"""
|
||||
period_days: int
|
||||
total_runs: int
|
||||
average_mape: Optional[float]
|
||||
average_accuracy: Optional[float]
|
||||
trends: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("validation/validate-date-range"),
|
||||
response_model=ValidationResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def validate_date_range(
|
||||
validation_request: ValidationRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate forecasts against actual sales for a date range
|
||||
|
||||
This endpoint:
|
||||
- Fetches forecasts for the specified date range
|
||||
- Retrieves corresponding actual sales data
|
||||
- Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
|
||||
- Stores performance metrics in the database
|
||||
- Returns validation summary
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting date range validation",
|
||||
tenant_id=tenant_id,
|
||||
start_date=validation_request.start_date.isoformat(),
|
||||
end_date=validation_request.end_date.isoformat(),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
result = await validation_service.validate_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=validation_request.start_date,
|
||||
end_date=validation_request.end_date,
|
||||
orchestration_run_id=validation_request.orchestration_run_id,
|
||||
triggered_by=validation_request.triggered_by
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Date range validation completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated")
|
||||
)
|
||||
|
||||
return ValidationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to validate date range",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to validate forecasts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
route_builder.build_base_route("validation/validate-yesterday"),
|
||||
response_model=ValidationResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def validate_yesterday(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Validate yesterday's forecasts against actual sales
|
||||
|
||||
Convenience endpoint for validating the most recent day's forecasts.
|
||||
This is typically called by the orchestrator as part of the daily workflow.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting yesterday validation",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
result = await validation_service.validate_yesterday(
|
||||
tenant_id=tenant_id,
|
||||
orchestration_run_id=orchestration_run_id,
|
||||
triggered_by="manual"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Yesterday validation completed",
|
||||
tenant_id=tenant_id,
|
||||
validation_run_id=result.get("validation_run_id"),
|
||||
forecasts_evaluated=result.get("forecasts_evaluated")
|
||||
)
|
||||
|
||||
return ValidationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to validate yesterday",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to validate yesterday's forecasts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/runs/{validation_run_id}"),
|
||||
response_model=ValidationRunResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_validation_run(
|
||||
validation_run_id: UUID = Path(..., description="Validation run ID"),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get details of a specific validation run
|
||||
|
||||
Returns complete information about a validation execution including:
|
||||
- Summary statistics
|
||||
- Overall accuracy metrics
|
||||
- Breakdown by product and location
|
||||
- Execution metadata
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
validation_run = await validation_service.get_validation_run(validation_run_id)
|
||||
|
||||
if not validation_run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Validation run {validation_run_id} not found"
|
||||
)
|
||||
|
||||
if validation_run.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this validation run"
|
||||
)
|
||||
|
||||
return ValidationRunResponse(**validation_run.to_dict())
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get validation run",
|
||||
validation_run_id=validation_run_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get validation run: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/runs"),
|
||||
response_model=List[ValidationRunResponse],
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_validation_runs(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
|
||||
skip: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get validation runs for a tenant
|
||||
|
||||
Returns a list of validation executions with pagination support.
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
runs = await validation_service.get_validation_runs_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
skip=skip
|
||||
)
|
||||
|
||||
return [ValidationRunResponse(**run.to_dict()) for run in runs]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get validation runs",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get validation runs: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
route_builder.build_base_route("validation/trends"),
|
||||
response_model=AccuracyTrendResponse,
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
||||
@require_user_role(['admin', 'owner', 'member'])
|
||||
async def get_accuracy_trends(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get accuracy trends over time
|
||||
|
||||
Returns validation accuracy metrics over the specified time period.
|
||||
Useful for monitoring model performance degradation and improvement.
|
||||
"""
|
||||
try:
|
||||
validation_service = ValidationService(db)
|
||||
|
||||
trends = await validation_service.get_accuracy_trends(
|
||||
tenant_id=tenant_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return AccuracyTrendResponse(**trends)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get accuracy trends",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get accuracy trends: {str(e)}"
|
||||
)
|
||||
Reference in New Issue
Block a user