Improve backend
This commit is contained in:
@@ -6,10 +6,20 @@ HTTP endpoints for demand forecasting and prediction operations
|
||||
from .forecasts import router as forecasts_router
|
||||
from .forecasting_operations import router as forecasting_operations_router
|
||||
from .analytics import router as analytics_router
|
||||
from .validation import router as validation_router
|
||||
from .historical_validation import router as historical_validation_router
|
||||
from .webhooks import router as webhooks_router
|
||||
from .performance_monitoring import router as performance_monitoring_router
|
||||
from .retraining import router as retraining_router
|
||||
|
||||
|
||||
# Public re-exports: one router per API module, aggregated here so the
# application factory can include them all from a single import.
__all__ = [
    "forecasts_router",
    "forecasting_operations_router",
    "analytics_router",
    "validation_router",
    "historical_validation_router",
    "webhooks_router",
    "performance_monitoring_router",
    "retraining_router",
]
|
||||
304
services/forecasting/app/api/historical_validation.py
Normal file
304
services/forecasting/app/api/historical_validation.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/historical_validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Historical Validation API - Backfill validation for late-arriving sales data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import date
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.historical_validation_service import HistoricalValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Builds tenant-scoped route paths for the forecasting service.
route_builder = RouteBuilder('forecasting')
# Router is mounted by the application factory; tag groups endpoints in OpenAPI docs.
router = APIRouter(tags=["historical-validation"])
# Module-level structured logger (structlog).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class DetectGapsRequest(BaseModel):
    """Request model for gap detection"""
    # Cap of 365 bounds the scan to at most one year of history.
    lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")
|
||||
|
||||
|
||||
class BackfillRequest(BaseModel):
    """Request model for manual backfill"""
    # NOTE(review): there is no check that end_date >= start_date here;
    # presumably the service layer handles an inverted range — confirm.
    start_date: date = Field(..., description="Start date for backfill")
    end_date: date = Field(..., description="End date for backfill")
|
||||
|
||||
|
||||
class SalesDataUpdateRequest(BaseModel):
    """Request model for registering sales data update"""
    start_date: date = Field(..., description="Start date of updated data")
    end_date: date = Field(..., description="End date of updated data")
    # ge=0 permits registering an update that touched no rows.
    records_affected: int = Field(..., ge=0, description="Number of records affected")
    update_source: str = Field(default="import", description="Source of update")
    import_job_id: Optional[str] = Field(None, description="Import job ID if applicable")
    # When True the service kicks off validation for the affected range immediately.
    auto_trigger_validation: bool = Field(default=True, description="Auto-trigger validation")
|
||||
|
||||
|
||||
class AutoBackfillRequest(BaseModel):
    """Request model for automatic backfill"""
    lookback_days: int = Field(default=90, ge=1, le=365, description="Days to look back")
    # Upper bound of 50 keeps a single request from queueing unbounded work.
    max_gaps_to_process: int = Field(default=10, ge=1, le=50, description="Max gaps to process")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/detect-gaps"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def detect_validation_gaps(
    request: DetectGapsRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Detect date ranges where forecasts exist but haven't been validated yet.

    Args:
        request: Gap-detection options (lookback window in days).
        tenant_id: Tenant whose forecasts are scanned.
        current_user: Authenticated user (injected), used for audit logging.
        db: Async database session (injected).

    Returns:
        Dict with gap count, the lookback window used, and one ISO-dated
        entry per gap period that needs validation backfill.

    Raises:
        HTTPException: 500 when gap detection fails.
    """
    try:
        logger.info(
            "Detecting validation gaps",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        gaps = await service.detect_validation_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days
        )

        return {
            "gaps_found": len(gaps),
            "lookback_days": request.lookback_days,
            "gaps": [
                {
                    "start_date": gap["start_date"].isoformat(),
                    "end_date": gap["end_date"].isoformat(),
                    "days_count": gap["days_count"]
                }
                for gap in gaps
            ]
        }

    except Exception as e:
        # exc_info=True keeps the traceback in the structured log entry;
        # previously only str(e) was recorded, losing the stack.
        logger.error(
            "Failed to detect validation gaps",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to detect validation gaps: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def backfill_validation(
    request: BackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Manually trigger validation backfill for a specific date range.

    Validates forecasts against sales data for historical periods.

    Returns:
        The service's backfill result payload, passed through unchanged.

    Raises:
        HTTPException: 500 when the backfill fails.
    """
    try:
        logger.info(
            "Manual validation backfill requested",
            tenant_id=tenant_id,
            start_date=request.start_date.isoformat(),
            end_date=request.end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.backfill_validation(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            triggered_by="manual"
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to backfill validation",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to backfill validation: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/auto-backfill"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def auto_backfill_validation_gaps(
    request: AutoBackfillRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Automatically detect and backfill validation gaps.

    Finds all date ranges with missing validations and processes them,
    bounded by max_gaps_to_process.

    Returns:
        The service's auto-backfill result payload, passed through unchanged.

    Raises:
        HTTPException: 500 when the auto-backfill fails.
    """
    try:
        logger.info(
            "Auto backfill requested",
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps=request.max_gaps_to_process,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.auto_backfill_gaps(
            tenant_id=tenant_id,
            lookback_days=request.lookback_days,
            max_gaps_to_process=request.max_gaps_to_process
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to auto backfill",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to auto backfill: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/register-sales-update"),
    status_code=status.HTTP_201_CREATED
)
@require_user_role(['admin', 'owner', 'member'])
async def register_sales_data_update(
    request: SalesDataUpdateRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Register a sales data update and optionally trigger validation.

    Call this endpoint after importing historical sales data to
    automatically trigger validation for the affected date range.

    Returns:
        The service's registration result payload, passed through unchanged.

    Raises:
        HTTPException: 500 when registration fails.
    """
    try:
        logger.info(
            "Registering sales data update",
            tenant_id=tenant_id,
            date_range=f"{request.start_date} to {request.end_date}",
            records_affected=request.records_affected,
            user_id=current_user.get("user_id")
        )

        service = HistoricalValidationService(db)

        result = await service.register_sales_data_update(
            tenant_id=tenant_id,
            start_date=request.start_date,
            end_date=request.end_date,
            records_affected=request.records_affected,
            update_source=request.update_source,
            import_job_id=request.import_job_id,
            auto_trigger_validation=request.auto_trigger_validation
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to register sales data update",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to register sales data update: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/pending"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_pending_validations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get pending sales data updates awaiting validation.

    Returns list of sales data updates that have been registered
    but not yet validated.

    Returns:
        Dict with a count and the serialized pending records.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    try:
        service = HistoricalValidationService(db)

        pending = await service.get_pending_validations(
            tenant_id=tenant_id,
            limit=limit
        )

        return {
            "pending_count": len(pending),
            "pending_validations": [record.to_dict() for record in pending]
        }

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to get pending validations",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get pending validations: {str(e)}"
        ) from e
|
||||
287
services/forecasting/app/api/performance_monitoring.py
Normal file
287
services/forecasting/app/api/performance_monitoring.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/performance_monitoring.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance Monitoring API - Track and analyze forecast accuracy over time
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.performance_monitoring_service import PerformanceMonitoringService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Builds tenant-scoped route paths for the forecasting service.
route_builder = RouteBuilder('forecasting')
# Router is mounted by the application factory; tag groups endpoints in OpenAPI docs.
router = APIRouter(tags=["performance-monitoring"])
# Module-level structured logger (structlog).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class AccuracySummaryRequest(BaseModel):
    """Request model for accuracy summary"""
    # Analysis window; capped at one year.
    days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
|
||||
|
||||
|
||||
class DegradationAnalysisRequest(BaseModel):
    """Request model for degradation analysis"""
    # Minimum of 7 days so the half-vs-half comparison has data on both sides.
    lookback_days: int = Field(default=30, ge=7, le=365, description="Days to analyze")
|
||||
|
||||
|
||||
class ModelAgeCheckRequest(BaseModel):
    """Request model for model age check"""
    # Models older than this many days are flagged for retraining.
    max_age_days: int = Field(default=30, ge=1, le=90, description="Max acceptable model age")
|
||||
|
||||
|
||||
class PerformanceReportRequest(BaseModel):
    """Request model for comprehensive performance report"""
    # Analysis window; capped at one year.
    days: int = Field(default=30, ge=1, le=365, description="Analysis period in days")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/accuracy-summary"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_summary(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=1, le=365, description="Analysis period in days"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get forecast accuracy summary for recent period.

    Returns overall metrics, validation coverage, and health status.

    Raises:
        HTTPException: 500 when the summary cannot be computed.
    """
    try:
        logger.info(
            "Getting accuracy summary",
            tenant_id=tenant_id,
            days=days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)

        summary = await service.get_accuracy_summary(
            tenant_id=tenant_id,
            days=days
        )

        return summary

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to get accuracy summary",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get accuracy summary: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/degradation-analysis"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def analyze_performance_degradation(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    lookback_days: int = Query(30, ge=7, le=365, description="Days to analyze"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Detect if forecast performance is degrading over time.

    Compares first half vs second half of period and identifies poor performers.

    Raises:
        HTTPException: 500 when the analysis fails.
    """
    try:
        logger.info(
            "Analyzing performance degradation",
            tenant_id=tenant_id,
            lookback_days=lookback_days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)

        analysis = await service.detect_performance_degradation(
            tenant_id=tenant_id,
            lookback_days=lookback_days
        )

        return analysis

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to analyze degradation",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to analyze degradation: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/model-age"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def check_model_age(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    max_age_days: int = Query(30, ge=1, le=90, description="Max acceptable model age"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Check if models are outdated and need retraining.

    Returns models in use and identifies those needing updates.

    Raises:
        HTTPException: 500 when the check fails.
    """
    try:
        logger.info(
            "Checking model age",
            tenant_id=tenant_id,
            max_age_days=max_age_days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)

        analysis = await service.check_model_age(
            tenant_id=tenant_id,
            max_age_days=max_age_days
        )

        return analysis

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to check model age",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to check model age: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("monitoring/performance-report"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def generate_performance_report(
    request: PerformanceReportRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Generate comprehensive performance report.

    Combines accuracy summary, degradation analysis, and model age check
    with actionable recommendations.

    Raises:
        HTTPException: 500 when report generation fails.
    """
    try:
        logger.info(
            "Generating performance report",
            tenant_id=tenant_id,
            days=request.days,
            user_id=current_user.get("user_id")
        )

        service = PerformanceMonitoringService(db)

        report = await service.generate_performance_report(
            tenant_id=tenant_id,
            days=request.days
        )

        return report

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to generate performance report",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate performance report: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("monitoring/health"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_health_status(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get quick health status for dashboards.

    Returns simplified health metrics for UI display, derived from a
    7-day accuracy summary.

    Raises:
        HTTPException: 500 when the health check fails.
    """
    try:
        service = PerformanceMonitoringService(db)

        # Get 7-day summary for quick health check
        summary = await service.get_accuracy_summary(
            tenant_id=tenant_id,
            days=7
        )

        if summary.get("status") == "no_data":
            return {
                "status": "unknown",
                "message": "No recent validation data available",
                "health_status": "unknown"
            }

        # Safe access: previously summary["average_metrics"] raised KeyError
        # (surfacing as a 500) whenever the service omitted the key.
        average_metrics = summary.get("average_metrics") or {}

        return {
            "status": "ok",
            "health_status": summary.get("health_status"),
            "current_mape": average_metrics.get("mape"),
            "accuracy_percentage": average_metrics.get("accuracy_percentage"),
            "validation_coverage": summary.get("coverage_percentage"),
            "last_7_days": {
                "validation_runs": summary.get("validation_runs"),
                "forecasts_evaluated": summary.get("total_forecasts_evaluated")
            }
        }

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to get health status",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get health status: {str(e)}"
        ) from e
|
||||
297
services/forecasting/app/api/retraining.py
Normal file
297
services/forecasting/app/api/retraining.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/retraining.py
|
||||
# ================================================================
|
||||
"""
|
||||
Retraining API - Trigger and manage model retraining based on performance
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.retraining_trigger_service import RetrainingTriggerService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Builds tenant-scoped route paths for the forecasting service.
route_builder = RouteBuilder('forecasting')
# Router is mounted by the application factory; tag groups endpoints in OpenAPI docs.
router = APIRouter(tags=["retraining"])
# Module-level structured logger (structlog).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class EvaluateRetrainingRequest(BaseModel):
    """Request model for retraining evaluation"""
    # Default False makes evaluation read-only unless explicitly opted in.
    auto_trigger: bool = Field(
        default=False,
        description="Automatically trigger retraining for poor performers"
    )
|
||||
|
||||
|
||||
class TriggerProductRetrainingRequest(BaseModel):
    """Request model for single product retraining"""
    inventory_product_id: UUID = Field(..., description="Product to retrain")
    reason: str = Field(..., description="Reason for retraining")
    # NOTE(review): priority is a free-form str here; the docstring implies
    # a closed set (low/normal/high) — an Enum/Literal would enforce it. Confirm
    # what values the service accepts before tightening.
    priority: str = Field(
        default="normal",
        description="Priority level: low, normal, high"
    )
|
||||
|
||||
|
||||
class TriggerBulkRetrainingRequest(BaseModel):
    """Request model for bulk retraining"""
    # NOTE(review): an empty list is accepted here; the endpoint short-circuits
    # or the service must handle it — verify expected behavior.
    product_ids: List[UUID] = Field(..., description="List of products to retrain")
    reason: str = Field(
        default="Bulk retraining requested",
        description="Reason for bulk retraining"
    )
|
||||
|
||||
|
||||
class ScheduledRetrainingCheckRequest(BaseModel):
    """Request model for scheduled retraining check"""
    # Models older than this many days are considered stale.
    max_model_age_days: int = Field(
        default=30,
        ge=1,
        le=90,
        description="Maximum acceptable model age"
    )
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/evaluate"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def evaluate_retraining_needs(
    request: EvaluateRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Evaluate performance and optionally trigger retraining.

    Analyzes 30-day performance and identifies products needing retraining.
    If auto_trigger=true, automatically triggers retraining for poor performers.

    Raises:
        HTTPException: 500 when the evaluation fails.
    """
    try:
        logger.info(
            "Evaluating retraining needs",
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        result = await service.evaluate_and_trigger_retraining(
            tenant_id=tenant_id,
            auto_trigger=request.auto_trigger
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to evaluate retraining needs",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to evaluate retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/trigger-product"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_product_retraining(
    request: TriggerProductRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for a specific product.

    Manually trigger model retraining for a single product.

    Raises:
        HTTPException: 500 when triggering retraining fails.
    """
    try:
        logger.info(
            "Triggering product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        # NOTE(review): this reaches into a private service method
        # (leading underscore). The service should expose a public
        # single-product trigger; fix belongs in the service layer.
        result = await service._trigger_product_retraining(
            tenant_id=tenant_id,
            inventory_product_id=request.inventory_product_id,
            reason=request.reason,
            priority=request.priority
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to trigger product retraining",
            tenant_id=tenant_id,
            product_id=request.inventory_product_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/trigger-bulk"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def trigger_bulk_retraining(
    request: TriggerBulkRetrainingRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Trigger retraining for multiple products.

    Bulk retraining operation for multiple products at once.

    Raises:
        HTTPException: 500 when the bulk trigger fails.
    """
    try:
        logger.info(
            "Triggering bulk retraining",
            tenant_id=tenant_id,
            product_count=len(request.product_ids),
            reason=request.reason,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        result = await service.trigger_bulk_retraining(
            tenant_id=tenant_id,
            product_ids=request.product_ids,
            reason=request.reason
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to trigger bulk retraining",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to trigger bulk retraining: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("retraining/recommendations"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_retraining_recommendations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get retraining recommendations without triggering.

    Returns recommendations for manual review and decision-making.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    try:
        logger.info(
            "Getting retraining recommendations",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        recommendations = await service.get_retraining_recommendations(
            tenant_id=tenant_id
        )

        return recommendations

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to get recommendations",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get recommendations: {str(e)}"
        ) from e
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("retraining/check-scheduled"),
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner'])
async def check_scheduled_retraining(
    request: ScheduledRetrainingCheckRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Check for models needing scheduled retraining based on age.

    Identifies models that haven't been updated in max_model_age_days.

    Raises:
        HTTPException: 500 when the check fails.
    """
    try:
        logger.info(
            "Checking scheduled retraining needs",
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days,
            user_id=current_user.get("user_id")
        )

        service = RetrainingTriggerService(db)

        result = await service.check_and_trigger_scheduled_retraining(
            tenant_id=tenant_id,
            max_model_age_days=request.max_model_age_days
        )

        return result

    except Exception as e:
        # exc_info=True preserves the traceback that str(e) alone loses.
        logger.error(
            "Failed to check scheduled retraining",
            tenant_id=tenant_id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to check scheduled retraining: {str(e)}"
        ) from e
|
||||
346
services/forecasting/app/api/validation.py
Normal file
346
services/forecasting/app/api/validation.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/validation.py
|
||||
# ================================================================
|
||||
"""
|
||||
Validation API - Forecast validation endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
|
||||
from typing import Dict, Any, List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.services.validation_service import ValidationService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import require_user_role
|
||||
from shared.routing import RouteBuilder
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
# Builds tenant-scoped route paths for the forecasting service.
route_builder = RouteBuilder('forecasting')
# Router that collects every validation endpoint under the "validation" tag.
router = APIRouter(tags=["validation"])
# Module-level structured logger (structlog key-value style).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request/Response Schemas
|
||||
# ================================================================
|
||||
|
||||
class ValidationRequest(BaseModel):
    """Request model for validation.

    Defines the date window to validate and optional metadata linking the
    run back to an orchestration execution.
    """
    start_date: datetime = Field(..., description="Start date for validation period")
    end_date: datetime = Field(..., description="End date for validation period")
    # Present when the validation is part of a larger orchestrated workflow.
    orchestration_run_id: Optional[UUID] = Field(None, description="Optional orchestration run ID")
    # Free-form label of what initiated the run (e.g. "manual").
    triggered_by: str = Field(default="manual", description="Trigger source")
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
    """Response model for validation results.

    Summary returned by the validation endpoints: run identity, coverage
    counts, and (when computable) aggregate accuracy figures.
    """
    validation_run_id: str
    status: str
    # Coverage: how many forecasts were considered and how many could be
    # matched against actual sales.
    forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int
    metrics_created: int
    # Aggregate accuracy metrics; optional because they cannot be computed
    # when no actuals were found.
    overall_metrics: Optional[Dict[str, float]] = None
    total_predicted_demand: Optional[float] = None
    total_actual_demand: Optional[float] = None
    duration_seconds: Optional[float] = None
    # Optional human-readable note from the service.
    message: Optional[str] = None
|
||||
|
||||
|
||||
class ValidationRunResponse(BaseModel):
    """Response model for validation run details.

    Mirrors the persisted validation-run record: identifiers, the validated
    window, execution timing, aggregate accuracy metrics, and per-dimension
    breakdowns. Populated from the run's ``to_dict()`` output.
    """
    # --- Identity ---
    id: str
    tenant_id: str
    orchestration_run_id: Optional[str]
    # --- Validated window (serialized date strings) ---
    validation_start_date: str
    validation_end_date: str
    # --- Execution timing / state ---
    started_at: str
    completed_at: Optional[str]
    duration_seconds: Optional[float]
    status: str
    # --- Coverage counts ---
    total_forecasts_evaluated: int
    forecasts_with_actuals: int
    forecasts_without_actuals: int
    # --- Aggregate accuracy metrics (optional when not computable) ---
    overall_mae: Optional[float]
    overall_mape: Optional[float]
    overall_rmse: Optional[float]
    overall_r2_score: Optional[float]
    overall_accuracy_percentage: Optional[float]
    # --- Demand totals over the window ---
    total_predicted_demand: float
    total_actual_demand: float
    # --- Breakdowns by dimension ---
    metrics_by_product: Optional[Dict[str, Any]]
    metrics_by_location: Optional[Dict[str, Any]]
    metrics_records_created: int
    # --- Execution metadata ---
    error_message: Optional[str]
    triggered_by: str
    execution_mode: str
|
||||
|
||||
|
||||
class AccuracyTrendResponse(BaseModel):
    """Response model for accuracy trends.

    Aggregated view of validation accuracy over a trailing window.
    """
    # Length of the analyzed window in days.
    period_days: int
    # Number of validation runs found inside the window.
    total_runs: int
    # Window-level averages; optional when no runs produced metrics.
    average_mape: Optional[float]
    average_accuracy: Optional[float]
    # Per-period trend points as returned by the validation service.
    trends: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/validate-date-range"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_date_range(
    validation_request: ValidationRequest,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate forecasts against actual sales for a date range

    This endpoint:
    - Fetches forecasts for the specified date range
    - Retrieves corresponding actual sales data
    - Calculates accuracy metrics (MAE, MAPE, RMSE, R², accuracy %)
    - Stores performance metrics in the database
    - Returns validation summary

    Raises:
        HTTPException 400: if end_date is earlier than start_date.
        HTTPException 500: if validation fails for any other reason.
    """
    # Reject an inverted window up front so a caller mistake surfaces as a
    # 400 Bad Request instead of a 500 from deeper in the service layer.
    if validation_request.end_date < validation_request.start_date:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="end_date must not be earlier than start_date"
        )

    try:
        logger.info(
            "Starting date range validation",
            tenant_id=tenant_id,
            start_date=validation_request.start_date.isoformat(),
            end_date=validation_request.end_date.isoformat(),
            user_id=current_user.get("user_id")
        )

        validation_service = ValidationService(db)

        result = await validation_service.validate_date_range(
            tenant_id=tenant_id,
            start_date=validation_request.start_date,
            end_date=validation_request.end_date,
            orchestration_run_id=validation_request.orchestration_run_id,
            triggered_by=validation_request.triggered_by
        )

        logger.info(
            "Date range validation completed",
            tenant_id=tenant_id,
            validation_run_id=result.get("validation_run_id"),
            forecasts_evaluated=result.get("forecasts_evaluated")
        )

        return ValidationResponse(**result)

    except Exception as e:
        logger.error(
            "Failed to validate date range",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate forecasts: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    route_builder.build_base_route("validation/validate-yesterday"),
    response_model=ValidationResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def validate_yesterday(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    orchestration_run_id: Optional[UUID] = Query(None, description="Optional orchestration run ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Validate yesterday's forecasts against actual sales.

    Convenience wrapper around the date-range validation that targets the
    most recent day. Typically invoked by the orchestrator as part of the
    daily workflow.

    Raises:
        HTTPException 500: if validation fails.
    """
    try:
        logger.info(
            "Starting yesterday validation",
            tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
        )

        service = ValidationService(db)
        summary = await service.validate_yesterday(
            tenant_id=tenant_id,
            orchestration_run_id=orchestration_run_id,
            triggered_by="manual",
        )

        logger.info(
            "Yesterday validation completed",
            tenant_id=tenant_id,
            validation_run_id=summary.get("validation_run_id"),
            forecasts_evaluated=summary.get("forecasts_evaluated"),
        )
        return ValidationResponse(**summary)

    except Exception as e:
        logger.error(
            "Failed to validate yesterday",
            tenant_id=tenant_id,
            error=str(e),
            error_type=type(e).__name__,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate yesterday's forecasts: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/runs/{validation_run_id}"),
    response_model=ValidationRunResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_run(
    validation_run_id: UUID = Path(..., description="Validation run ID"),
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get details of a specific validation run

    Returns complete information about a validation execution including:
    - Summary statistics
    - Overall accuracy metrics
    - Breakdown by product and location
    - Execution metadata

    Raises:
        HTTPException 404: if the run does not exist or belongs to a
            different tenant (existence is deliberately not disclosed
            across tenants).
        HTTPException 500: on unexpected failure.
    """
    try:
        validation_service = ValidationService(db)

        validation_run = await validation_service.get_validation_run(validation_run_id)

        # Treat "not found" and "owned by another tenant" identically:
        # answering 403 for the latter would leak that the run exists.
        if not validation_run or validation_run.tenant_id != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Validation run {validation_run_id} not found"
            )

        return ValidationRunResponse(**validation_run.to_dict())

    except HTTPException:
        # Re-raise intentional API errors untouched (404 above).
        raise
    except Exception as e:
        logger.error(
            "Failed to get validation run",
            validation_run_id=validation_run_id,
            error=str(e),
            error_type=type(e).__name__
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation run: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/runs"),
    response_model=List[ValidationRunResponse],
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_validation_runs(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    limit: int = Query(50, ge=1, le=100, description="Number of records to return"),
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get validation runs for a tenant.

    Lists validation executions, newest pagination window first, using
    ``limit``/``skip`` pagination.

    Raises:
        HTTPException 500: if the lookup fails.
    """
    try:
        service = ValidationService(db)
        records = await service.get_validation_runs_by_tenant(
            tenant_id=tenant_id,
            limit=limit,
            skip=skip,
        )
        # Serialize each ORM record through its to_dict() contract.
        responses = [ValidationRunResponse(**record.to_dict()) for record in records]
        return responses

    except Exception as e:
        logger.error(
            "Failed to get validation runs",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get validation runs: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    route_builder.build_base_route("validation/trends"),
    response_model=AccuracyTrendResponse,
    status_code=status.HTTP_200_OK
)
@require_user_role(['admin', 'owner', 'member'])
async def get_accuracy_trends(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    days: int = Query(30, ge=1, le=365, description="Number of days to analyze"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get accuracy trends over time.

    Returns validation accuracy metrics over the trailing ``days`` window;
    useful for monitoring model performance degradation and improvement.

    Raises:
        HTTPException 500: if trend computation fails.
    """
    try:
        service = ValidationService(db)
        trend_data = await service.get_accuracy_trends(tenant_id=tenant_id, days=days)
        return AccuracyTrendResponse(**trend_data)

    except Exception as e:
        logger.error(
            "Failed to get accuracy trends",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get accuracy trends: {str(e)}"
        )
|
||||
174
services/forecasting/app/api/webhooks.py
Normal file
174
services/forecasting/app/api/webhooks.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/webhooks.py
|
||||
# ================================================================
|
||||
"""
|
||||
Webhooks API - Receive events from other services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Header
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import date
|
||||
import structlog
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from app.jobs.sales_data_listener import (
|
||||
handle_sales_import_completion,
|
||||
handle_pos_sync_completion
|
||||
)
|
||||
from shared.routing import RouteBuilder
|
||||
|
||||
# Builds tenant-scoped route paths for the forecasting service.
route_builder = RouteBuilder('forecasting')
# Router that collects the webhook receiver endpoints under the "webhooks" tag.
router = APIRouter(tags=["webhooks"])
# Module-level structured logger (structlog key-value style).
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Request Schemas
|
||||
# ================================================================
|
||||
|
||||
class SalesImportWebhook(BaseModel):
    """Webhook payload for sales data import completion.

    Sent by the sales service when a bulk data import finishes; describes
    the tenant, the import job, and the date span the import covered.
    """
    tenant_id: UUID = Field(..., description="Tenant ID")
    import_job_id: str = Field(..., description="Import job ID")
    start_date: date = Field(..., description="Start date of imported data")
    end_date: date = Field(..., description="End date of imported data")
    records_count: int = Field(..., ge=0, description="Number of records imported")
    # Origin label for the import (defaults to a generic "import").
    import_source: str = Field(default="import", description="Source of import")
|
||||
|
||||
|
||||
class POSSyncWebhook(BaseModel):
    """Webhook payload for POS sync completion.

    Sent by the POS service when a single day's synchronization completes.
    """
    tenant_id: UUID = Field(..., description="Tenant ID")
    sync_log_id: str = Field(..., description="POS sync log ID")
    sync_date: date = Field(..., description="Date of synced data")
    records_synced: int = Field(..., ge=0, description="Number of records synced")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.post(
    "/webhooks/sales-import-completed",
    status_code=status.HTTP_202_ACCEPTED
)
async def sales_import_completed_webhook(
    payload: SalesImportWebhook,
    x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
    """
    Webhook endpoint for sales data import completion.

    Called by the sales service when a data import completes; hands the
    event to the sales-data listener, which drives validation backfill for
    the imported date range. The handler is awaited before responding.

    Note: In production, the webhook signature should be verified to
    ensure the request comes from a trusted source.

    Raises:
        HTTPException 500: if the event handler fails.
    """
    try:
        logger.info(
            "Received sales import completion webhook",
            tenant_id=payload.tenant_id,
            import_job_id=payload.import_job_id,
            date_range=f"{payload.start_date} to {payload.end_date}",
        )

        # TODO: verify webhook signature in production, e.g.
        # if not verify_webhook_signature(x_webhook_signature, payload):
        #     raise HTTPException(status_code=401, detail="Invalid webhook signature")

        handler_result = await handle_sales_import_completion(
            tenant_id=payload.tenant_id,
            import_job_id=payload.import_job_id,
            start_date=payload.start_date,
            end_date=payload.end_date,
            records_count=payload.records_count,
            import_source=payload.import_source,
        )

        response = {
            "status": "accepted",
            "message": "Sales import completion event received and processing",
            "result": handler_result,
        }
        return response

    except Exception as e:
        logger.error(
            "Failed to process sales import webhook",
            payload=payload.dict(),
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process webhook: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    "/webhooks/pos-sync-completed",
    status_code=status.HTTP_202_ACCEPTED
)
async def pos_sync_completed_webhook(
    payload: POSSyncWebhook,
    x_webhook_signature: Optional[str] = Header(None, description="Webhook signature for verification")
):
    """
    Webhook endpoint for POS sync completion.

    Called by the POS service when data synchronization completes; hands
    the event to the sales-data listener, which triggers validation for
    the synced date.

    Raises:
        HTTPException 500: if the event handler fails.
    """
    try:
        logger.info(
            "Received POS sync completion webhook",
            tenant_id=payload.tenant_id,
            sync_log_id=payload.sync_log_id,
            sync_date=payload.sync_date.isoformat(),
        )

        # TODO: verify webhook signature in production, e.g.
        # if not verify_webhook_signature(x_webhook_signature, payload):
        #     raise HTTPException(status_code=401, detail="Invalid webhook signature")

        handler_result = await handle_pos_sync_completion(
            tenant_id=payload.tenant_id,
            sync_log_id=payload.sync_log_id,
            sync_date=payload.sync_date,
            records_synced=payload.records_synced,
        )

        response = {
            "status": "accepted",
            "message": "POS sync completion event received and processing",
            "result": handler_result,
        }
        return response

    except Exception as e:
        logger.error(
            "Failed to process POS sync webhook",
            payload=payload.dict(),
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process webhook: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get(
    "/webhooks/health",
    status_code=status.HTTP_200_OK
)
async def webhook_health_check():
    """Health check endpoint for webhook receiver.

    Reports the receiver as healthy and lists the webhook paths it serves.
    """
    served_endpoints = [
        "/webhooks/sales-import-completed",
        "/webhooks/pos-sync-completed",
    ]
    return {
        "status": "healthy",
        "service": "forecasting-webhooks",
        "endpoints": served_endpoints,
    }
|
||||
Reference in New Issue
Block a user