# services/forecasting/app/api/forecasting_operations.py
"""
Forecasting Operations API - Business operations for forecast generation and predictions
"""

import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from typing import List, Dict, Any, Optional
from datetime import date, datetime, timezone
import uuid

from app.services.forecasting_service import EnhancedForecastingService
from app.services.prediction_service import PredictionService
from app.services.forecast_cache import get_forecast_cache_service
from app.schemas.forecasts import (
    ForecastRequest, ForecastResponse, BatchForecastRequest,
    BatchForecastResponse, MultiDayForecastResponse
)
from shared.auth.decorators import get_current_user_dep
from shared.database.base import create_database_manager
from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
from app.models import AuditLog
from shared.routing import RouteBuilder
from shared.auth.access_control import require_user_role, service_only_access
from shared.security import create_audit_logger, create_rate_limiter, AuditSeverity, AuditAction
from shared.subscription.plans import get_forecast_quota, get_forecast_horizon_limit
from shared.redis_utils import get_redis_client

route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
router = APIRouter(tags=["forecasting-operations"])

# Initialize audit logger
audit_logger = create_audit_logger("forecasting-service", AuditLog)


async def get_rate_limiter():
    """Dependency for rate limiter"""
    redis_client = await get_redis_client()
    return create_rate_limiter(redis_client)


def get_enhanced_forecasting_service():
    """Dependency injection for EnhancedForecastingService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return EnhancedForecastingService(database_manager)


def get_enhanced_prediction_service():
    """Dependency injection for enhanced PredictionService"""
    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    return PredictionService(database_manager)
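
# The providers above build a fresh database manager per call, which keeps the
# endpoints stateless. In tests they can be swapped out via FastAPI's standard
# dependency-override mechanism -- a minimal sketch, assuming a hypothetical
# FakeForecastingService stub:
#   app.dependency_overrides[get_enhanced_forecasting_service] = (
#       lambda: FakeForecastingService()
#   )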


@router.post(
    route_builder.build_operations_route("single"),
    response_model=ForecastResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@track_execution_time("enhanced_single_forecast_duration_seconds", "forecasting-service")
async def generate_single_forecast(
    request: ForecastRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate a single product forecast with caching support"""
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Generating single forecast",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    forecast_date=request.forecast_date.isoformat())

        if metrics:
            metrics.increment_counter("single_forecasts_total")

        # Initialize cache service
        cache_service = get_forecast_cache_service(settings.REDIS_URL)

        # Check cache first
        cached_forecast = await cache_service.get_cached_forecast(
            tenant_id=uuid.UUID(tenant_id),
            product_id=uuid.UUID(request.inventory_product_id),
            forecast_date=request.forecast_date
        )

        if cached_forecast:
            if metrics:
                metrics.increment_counter("forecast_cache_hits_total")
            logger.info("Returning cached forecast",
                        tenant_id=tenant_id,
                        forecast_id=cached_forecast.get('id'))
            return ForecastResponse(**cached_forecast)

        # Cache miss - generate forecast
        if metrics:
            metrics.increment_counter("forecast_cache_misses_total")

        forecast = await enhanced_forecasting_service.generate_forecast(
            tenant_id=tenant_id,
            request=request
        )

        # Cache the result
        await cache_service.cache_forecast(
            tenant_id=uuid.UUID(tenant_id),
            product_id=uuid.UUID(request.inventory_product_id),
            forecast_date=request.forecast_date,
            forecast_data=forecast.dict()
        )

        if metrics:
            metrics.increment_counter("single_forecasts_success_total")

        logger.info("Single forecast generated successfully",
                    tenant_id=tenant_id,
                    forecast_id=forecast.id)

        return forecast

    except ValueError as e:
        if metrics:
            metrics.increment_counter("forecast_validation_errors_total")
        logger.error("Forecast validation error", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("single_forecasts_errors_total")
        logger.error("Single forecast generation failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Forecast generation failed"
        )
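
# The handler above is a straightforward cache-aside flow: look up the
# (tenant, product, forecast_date) key in Redis, fall back to the model on a
# miss, then write the fresh result back so the next identical request is a
# hit. Because forecast_date is part of the key, requests differing only in
# date never collide.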


@router.post(
    route_builder.build_operations_route("multi-day"),
    response_model=MultiDayForecastResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@track_execution_time("enhanced_multi_day_forecast_duration_seconds", "forecasting-service")
async def generate_multi_day_forecast(
    request: ForecastRequest,
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Generate multiple daily forecasts for the specified period"""
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Generating multi-day forecast",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    forecast_days=request.forecast_days,
                    forecast_date=request.forecast_date.isoformat())

        if metrics:
            metrics.increment_counter("multi_day_forecasts_total")

        if request.forecast_days <= 0 or request.forecast_days > 30:
            raise ValueError("forecast_days must be between 1 and 30")

        forecast_result = await enhanced_forecasting_service.generate_multi_day_forecast(
            tenant_id=tenant_id,
            request=request
        )

        if metrics:
            metrics.increment_counter("multi_day_forecasts_success_total")

        logger.info("Multi-day forecast generated successfully",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    forecast_days=len(forecast_result.get("forecasts", [])))

        return forecast_result

    except ValueError as e:
        if metrics:
            metrics.increment_counter("forecast_validation_errors_total")
        logger.error("Multi-day forecast validation error", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("multi_day_forecasts_errors_total")
        logger.error("Multi-day forecast generation failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Multi-day forecast generation failed"
        )


async def execute_batch_forecast_background(
    tenant_id: str,
    batch_id: str,
    inventory_product_ids: List[str],
    forecast_days: int,
    batch_name: str
):
    """
    Background task for batch forecast generation.
    Prevents blocking the API thread for long-running batch operations.
    """
    from app.repositories import PredictionBatchRepository

    logger.info("Starting background batch forecast",
                batch_id=batch_id,
                tenant_id=tenant_id,
                product_count=len(inventory_product_ids))

    database_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")
    forecasting_service = EnhancedForecastingService(database_manager)

    try:
        # Update batch status to running
        async with database_manager.get_session() as session:
            batch_repo = PredictionBatchRepository(session)
            await batch_repo.update(
                batch_id,
                {"status": "processing", "completed_products": 0}
            )
            await session.commit()

        # Generate forecasts for all products
        batch_request = BatchForecastRequest(
            tenant_id=tenant_id,
            batch_name=batch_name,
            inventory_product_ids=inventory_product_ids,
            forecast_days=forecast_days
        )

        result = await forecasting_service.generate_batch_forecasts(
            tenant_id=tenant_id,
            request=batch_request
        )

        # Update batch status to completed
        async with database_manager.get_session() as session:
            batch_repo = PredictionBatchRepository(session)
            await batch_repo.update(
                batch_id,
                {
                    "status": "completed",
                    "completed_at": datetime.now(timezone.utc),
                    "completed_products": result.get("successful_forecasts", 0),
                    "failed_products": result.get("failed_forecasts", 0)
                }
            )
            await session.commit()

        logger.info("Background batch forecast completed",
                    batch_id=batch_id,
                    successful=result.get("successful_forecasts", 0),
                    failed=result.get("failed_forecasts", 0))

    except Exception as e:
        logger.error("Background batch forecast failed",
                     batch_id=batch_id,
                     error=str(e))

        try:
            async with database_manager.get_session() as session:
                batch_repo = PredictionBatchRepository(session)
                await batch_repo.update(
                    batch_id,
                    {
                        "status": "failed",
                        "completed_at": datetime.now(timezone.utc),
                        "error_message": str(e)
                    }
                )
                await session.commit()
        except Exception as update_error:
            logger.error("Failed to update batch status after error",
                         batch_id=batch_id,
                         error=str(update_error))
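
# This task builds its own database manager rather than reusing a
# request-scoped dependency: BackgroundTasks run after the HTTP response has
# been sent, so any session tied to the request lifecycle would already be
# closed. Failures are written back to the batch row, which is the only
# progress signal callers have once the endpoint has returned.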


@router.post(
    route_builder.build_operations_route("batch"),
    response_model=BatchForecastResponse
)
@require_user_role(['admin', 'owner'])
@track_execution_time("enhanced_batch_forecast_duration_seconds", "forecasting-service")
async def generate_batch_forecast(
    request: BatchForecastRequest,
    background_tasks: BackgroundTasks,
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """
    Generate forecasts for multiple products in batch (Admin+ only, quota enforced).

    Large batches are executed as background tasks to prevent API timeouts;
    the endpoint returns immediately with a batch_id for status tracking.
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Generating batch forecast",
                    tenant_id=tenant_id,
                    product_count=len(request.inventory_product_ids or []))

        if metrics:
            metrics.increment_counter("batch_forecasts_total")

        # Check if we need to get all products instead of specific ones
        inventory_product_ids = request.inventory_product_ids
        if inventory_product_ids is None or len(inventory_product_ids) == 0:
            # If no specific products were requested, fetch all products for
            # the tenant from the inventory service and forecast all of them
            from shared.clients.inventory_client import InventoryServiceClient

            inventory_client = InventoryServiceClient(settings)
            all_ingredients = await inventory_client.get_all_ingredients(tenant_id=tenant_id)
            inventory_product_ids = [str(ingredient['id']) for ingredient in all_ingredients] if all_ingredients else []

        # If still no products, return early with a success response
        if not inventory_product_ids:
            logger.info("No products found for forecasting", tenant_id=tenant_id)
            now = datetime.now(timezone.utc)
            return BatchForecastResponse(
                id=str(uuid.uuid4()),
                tenant_id=tenant_id,
                batch_name=getattr(request, 'batch_name', f"orchestrator-batch-{datetime.now().strftime('%Y%m%d')}"),
                status="completed",
                total_products=0,
                completed_products=0,
                failed_products=0,
                requested_at=now,
                completed_at=now,
                processing_time_ms=0,
                forecasts=None,
                error_message=None
            )

        batch_name = getattr(request, 'batch_name', f"batch-{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        forecast_days = getattr(request, 'forecast_days', 7)

        # Create the batch record first
        now = datetime.now(timezone.utc)

        async with enhanced_forecasting_service.database_manager.get_session() as session:
            from app.repositories import PredictionBatchRepository
            batch_repo = PredictionBatchRepository(session)

            batch_data = {
                "tenant_id": tenant_id,
                "batch_name": batch_name,
                "total_products": len(inventory_product_ids),
                "forecast_days": forecast_days,
                "status": "pending"
            }

            batch = await batch_repo.create_batch(batch_data)
            batch_id = str(batch.id)
            await session.commit()

        # Large batches (>5 products) run as a background task to prevent API
        # timeouts; small batches execute synchronously for immediate results
        use_background = len(inventory_product_ids) > 5

        if use_background:
            # Queue background task
            background_tasks.add_task(
                execute_batch_forecast_background,
                tenant_id=tenant_id,
                batch_id=batch_id,
                inventory_product_ids=inventory_product_ids,
                forecast_days=forecast_days,
                batch_name=batch_name
            )

            logger.info("Batch forecast queued for background processing",
                        tenant_id=tenant_id,
                        batch_id=batch_id,
                        product_count=len(inventory_product_ids))

            # Return immediately with pending status
            return BatchForecastResponse(
                id=batch_id,
                tenant_id=tenant_id,
                batch_name=batch_name,
                status="pending",
                total_products=len(inventory_product_ids),
                completed_products=0,
                failed_products=0,
                requested_at=now,
                completed_at=None,
                processing_time_ms=0,
                forecasts=None,
                error_message=None
            )
        else:
            # Small batch - execute synchronously
            updated_request = BatchForecastRequest(
                tenant_id=tenant_id,
                batch_name=batch_name,
                inventory_product_ids=inventory_product_ids,
                forecast_days=forecast_days
            )

            batch_result = await enhanced_forecasting_service.generate_batch_forecasts(
                tenant_id=tenant_id,
                request=updated_request
            )

            if metrics:
                metrics.increment_counter("batch_forecasts_success_total")

            logger.info("Batch forecast completed synchronously",
                        tenant_id=tenant_id,
                        total_forecasts=batch_result.get('total_forecasts', 0))

            # Convert the service result to BatchForecastResponse format
            return BatchForecastResponse(
                id=batch_id,
                tenant_id=tenant_id,
                batch_name=batch_name,
                status="completed",
                total_products=batch_result.get('total_forecasts', 0),
                completed_products=batch_result.get('successful_forecasts', 0),
                failed_products=batch_result.get('failed_forecasts', 0),
                requested_at=now,
                completed_at=datetime.now(timezone.utc),
                processing_time_ms=0,
                forecasts=[],
                error_message=None
            )

    except ValueError as e:
        if metrics:
            metrics.increment_counter("forecast_validation_errors_total")
        logger.error("Batch forecast validation error", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("batch_forecasts_errors_total")
        logger.error("Batch forecast generation failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Batch forecast generation failed"
        )
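
# A note on the size threshold: 5 products is a heuristic cut-off, not a tuned
# constant. Below it the caller gets final counts inline in the response;
# above it the response only acknowledges the batch, and the batch row created
# above becomes the record to poll for completion.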


@router.post(
    route_builder.build_operations_route("realtime")
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@track_execution_time("enhanced_realtime_prediction_duration_seconds", "forecasting-service")
async def generate_realtime_prediction(
    prediction_request: Dict[str, Any],
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Generate real-time prediction"""
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Generating real-time prediction",
                    tenant_id=tenant_id,
                    inventory_product_id=prediction_request.get("inventory_product_id"))

        if metrics:
            metrics.increment_counter("realtime_predictions_total")

        required_fields = ["inventory_product_id", "model_id", "features"]
        missing_fields = [field for field in required_fields if field not in prediction_request]
        if missing_fields:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Missing required fields: {missing_fields}"
            )

        prediction_result = await prediction_service.predict_with_weather_forecast(
            model_id=prediction_request["model_id"],
            model_path=prediction_request.get("model_path", ""),
            features=prediction_request["features"],
            tenant_id=tenant_id,
            days=prediction_request.get("days", 7),
            confidence_level=prediction_request.get("confidence_level", 0.8)
        )

        if metrics:
            metrics.increment_counter("realtime_predictions_success_total")

        logger.info("Real-time prediction generated successfully",
                    tenant_id=tenant_id,
                    days=len(prediction_result))

        return {
            "tenant_id": tenant_id,
            "inventory_product_id": prediction_request["inventory_product_id"],
            "model_id": prediction_request["model_id"],
            "predictions": prediction_result,
            "days": len(prediction_result),
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

    except HTTPException:
        raise
    except ValueError as e:
        if metrics:
            metrics.increment_counter("prediction_validation_errors_total")
        logger.error("Prediction validation error", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        if metrics:
            metrics.increment_counter("realtime_predictions_errors_total")
        logger.error("Real-time prediction failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Real-time prediction failed"
        )
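
# Example request body for the realtime endpoint (illustrative values only;
# "features" is whatever the loaded model expects, and model_path, days, and
# confidence_level are optional with the defaults shown in the code above):
#   {
#       "inventory_product_id": "8f2b...",
#       "model_id": "3c9a...",
#       "features": {"day_of_week": 4, "temperature_c": 21.5},
#       "days": 7,
#       "confidence_level": 0.8
#   }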


@router.post(
    route_builder.build_operations_route("batch-predictions")
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
async def generate_batch_predictions(
    predictions_request: List[Dict[str, Any]],
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: dict = Depends(get_current_user_dep),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Generate batch predictions"""
    try:
        logger.info("Generating batch predictions", tenant_id=tenant_id, count=len(predictions_request))

        results = []
        for pred_request in predictions_request:
            try:
                prediction_result = await prediction_service.predict_with_weather_forecast(
                    model_id=pred_request["model_id"],
                    model_path=pred_request.get("model_path", ""),
                    features=pred_request["features"],
                    tenant_id=tenant_id,
                    days=pred_request.get("days", 7),
                    confidence_level=pred_request.get("confidence_level", 0.8)
                )
                results.append({
                    "inventory_product_id": pred_request.get("inventory_product_id"),
                    "predictions": prediction_result,
                    "success": True
                })
            except Exception as e:
                results.append({
                    "inventory_product_id": pred_request.get("inventory_product_id"),
                    "error": str(e),
                    "success": False
                })

        return {"predictions": results, "total": len(results)}

    except Exception as e:
        logger.error("Batch predictions failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Batch predictions failed"
        )
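
# Per-item failures are isolated: a prediction that raises is recorded as
# {"success": False, "error": ...} in its slot and the loop continues, so the
# endpoint returns 200 with a mixed result set rather than failing the whole
# batch on one bad model or feature payload.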


@router.post(
    route_builder.build_operations_route("validate-predictions")
)
async def validate_predictions(
    tenant_id: str = Path(..., description="Tenant ID"),
    start_date: date = Query(...),
    end_date: date = Query(...),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Validate predictions against actual sales data"""
    try:
        logger.info("Validating predictions", tenant_id=tenant_id)

        validation_results = await enhanced_forecasting_service.validate_predictions(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        return validation_results

    except Exception as e:
        logger.error("Prediction validation failed", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Prediction validation failed"
        )


@router.get(
    route_builder.build_operations_route("statistics")
)
async def get_forecast_statistics(
    tenant_id: str = Path(..., description="Tenant ID"),
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """Get forecast statistics"""
    try:
        logger.info("Getting forecast statistics", tenant_id=tenant_id)

        stats = await enhanced_forecasting_service.get_forecast_statistics(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )

        return stats

    except Exception as e:
        logger.error("Failed to get forecast statistics", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve forecast statistics"
        )


@router.delete(
    route_builder.build_operations_route("cache")
)
async def clear_prediction_cache(
    tenant_id: str = Path(..., description="Tenant ID"),
    prediction_service: PredictionService = Depends(get_enhanced_prediction_service)
):
    """Clear prediction cache"""
    try:
        logger.info("Clearing prediction cache", tenant_id=tenant_id)

        await prediction_service.clear_cache(tenant_id=tenant_id)

        return {"message": "Prediction cache cleared successfully"}

    except Exception as e:
        logger.error("Failed to clear prediction cache", error=str(e), tenant_id=tenant_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to clear prediction cache"
        )


@router.post(
    route_builder.build_operations_route("validate-forecasts"),
    response_model=dict
)
@service_only_access
@track_execution_time("validate_forecasts_duration_seconds", "forecasting-service")
async def validate_forecasts(
    validation_date: date = Query(..., description="Date to validate forecasts for"),
    tenant_id: str = Path(..., description="Tenant ID"),
    request_obj: Request = None,
    current_user: dict = Depends(get_current_user_dep),
    enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
):
    """
    Validate forecasts for a specific date against actual sales.
    Calculates MAPE, RMSE, MAE and identifies products with poor accuracy.

    This endpoint is called by the orchestrator during Step 5 to validate
    yesterday's forecasts and trigger retraining if needed.

    Args:
        validation_date: Date to validate forecasts for
        tenant_id: Tenant ID

    Returns:
        Dict with overall metrics and poor accuracy products list:
        - overall_mape: Mean Absolute Percentage Error across all products
        - overall_rmse: Root Mean Squared Error across all products
        - overall_mae: Mean Absolute Error across all products
        - products_validated: Number of products validated
        - poor_accuracy_products: List of products with MAPE > 30%
    """
    metrics = get_metrics_collector(request_obj)

    try:
        logger.info("Validating forecasts for date",
                    tenant_id=tenant_id,
                    validation_date=validation_date.isoformat())

        if metrics:
            metrics.increment_counter("forecast_validations_total")

        # Get all forecasts for the validation date
        from app.repositories.forecast_repository import ForecastRepository
        from shared.clients.sales_client import SalesServiceClient

        db_manager = create_database_manager(settings.DATABASE_URL, "forecasting-service")

        async with db_manager.get_session() as session:
            forecast_repo = ForecastRepository(session)

            # Get forecasts for the validation date
            forecasts = await forecast_repo.get_forecasts_by_date(
                tenant_id=uuid.UUID(tenant_id),
                forecast_date=validation_date
            )

            if not forecasts:
                logger.warning("No forecasts found for validation date",
                               tenant_id=tenant_id,
                               validation_date=validation_date.isoformat())
                return {
                    "overall_mape": 0,
                    "overall_rmse": 0,
                    "overall_mae": 0,
                    "products_validated": 0,
                    "poor_accuracy_products": []
                }

            # Get actual sales for the validation date from the sales service
            sales_client = SalesServiceClient(settings, "forecasting-service")
            actual_sales_response = await sales_client.get_sales_by_date_range(
                tenant_id=tenant_id,
                start_date=validation_date,
                end_date=validation_date
            )

            # Create sales lookup dict
            sales_dict = {}
            if actual_sales_response and 'sales' in actual_sales_response:
                for sale in actual_sales_response['sales']:
                    product_id = sale.get('inventory_product_id')
                    quantity = sale.get('quantity', 0)
                    if product_id:
                        # Aggregate quantities for the same product
                        sales_dict[product_id] = sales_dict.get(product_id, 0) + quantity

            # Calculate metrics for each product
            import numpy as np

            mape_list = []
            rmse_list = []
            mae_list = []
            poor_accuracy_products = []

            for forecast in forecasts:
                product_id = str(forecast.inventory_product_id)
                actual_quantity = sales_dict.get(product_id)

                # Skip if no actual sales data
                if actual_quantity is None:
                    continue

                predicted_quantity = forecast.predicted_demand

                # Calculate errors
                absolute_error = abs(predicted_quantity - actual_quantity)
                squared_error = (predicted_quantity - actual_quantity) ** 2

                # Calculate percentage error (avoid division by zero)
                if actual_quantity > 0:
                    percentage_error = (absolute_error / actual_quantity) * 100
                else:
                    # If actual is 0 but predicted is not, treat as 100% error
                    percentage_error = 100 if predicted_quantity > 0 else 0

                mape_list.append(percentage_error)
                rmse_list.append(squared_error)
                mae_list.append(absolute_error)

                # Track products with poor accuracy
                if percentage_error > 30:
                    poor_accuracy_products.append({
                        "product_id": product_id,
                        "mape": round(percentage_error, 2),
                        "predicted": round(predicted_quantity, 2),
                        "actual": round(actual_quantity, 2)
                    })

            # Calculate overall metrics
            overall_mape = np.mean(mape_list) if mape_list else 0
            overall_rmse = np.sqrt(np.mean(rmse_list)) if rmse_list else 0
            overall_mae = np.mean(mae_list) if mae_list else 0
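
            # Worked example with illustrative numbers: for predicted=12,
            # actual=10 -> absolute_error=2, squared_error=4, MAPE=20.0; add a
            # second product with predicted=5, actual=10 (50% error) and the
            # aggregates become overall_mape = mean([20, 50]) = 35.0,
            # overall_mae = mean([2, 5]) = 3.5, and
            # overall_rmse = sqrt(mean([4, 25])) = sqrt(14.5), about 3.81.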

            result = {
                "overall_mape": round(overall_mape, 2),
                "overall_rmse": round(overall_rmse, 2),
                "overall_mae": round(overall_mae, 2),
                "products_validated": len(mape_list),
                "poor_accuracy_products": poor_accuracy_products
            }

            logger.info("Forecast validation complete",
                        tenant_id=tenant_id,
                        validation_date=validation_date.isoformat(),
                        overall_mape=result["overall_mape"],
                        products_validated=result["products_validated"],
                        poor_accuracy_count=len(poor_accuracy_products))

            if metrics:
                metrics.increment_counter("forecast_validations_completed_total")
                metrics.observe_histogram("forecast_validation_mape", overall_mape)

            return result

    except Exception as e:
        logger.error("Failed to validate forecasts",
                     error=str(e),
                     tenant_id=tenant_id,
                     validation_date=validation_date.isoformat())
        if metrics:
            metrics.increment_counter("forecast_validations_failed_total")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate forecasts: {str(e)}"
        )


# ============================================================================
# Tenant Data Deletion Operations (Internal Service Only)
# ============================================================================

@router.delete(
    route_builder.build_base_route("tenant/{tenant_id}", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def delete_tenant_data(
    tenant_id: str = Path(..., description="Tenant ID to delete data for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Delete all forecasting data for a tenant (internal service only).

    This endpoint is called by the orchestrator during tenant deletion.
    It permanently deletes all forecasting-related data including:
    - Forecasts (all time periods)
    - Prediction batches
    - Model performance metrics
    - Prediction cache
    - Audit logs

    **WARNING**: This operation is irreversible!

    Returns:
        Deletion summary with counts of deleted records
    """
    from app.services.tenant_deletion_service import ForecastingTenantDeletionService

    try:
        logger.info("forecasting.tenant_deletion.api_called", tenant_id=tenant_id)

        db_manager = create_database_manager(settings.DATABASE_URL, "forecasting")

        async with db_manager.get_session() as session:
            deletion_service = ForecastingTenantDeletionService(session)
            result = await deletion_service.safe_delete_tenant_data(tenant_id)

            if not result.success:
                raise HTTPException(
                    status_code=500,
                    detail=f"Tenant data deletion failed: {', '.join(result.errors)}"
                )

            return {
                "message": "Tenant data deletion completed successfully",
                "summary": result.to_dict()
            }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("forecasting.tenant_deletion.api_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete tenant data: {str(e)}"
        )


@router.get(
    route_builder.build_base_route("tenant/{tenant_id}/deletion-preview", include_tenant_prefix=False),
    response_model=dict
)
@service_only_access
async def preview_tenant_data_deletion(
    tenant_id: str = Path(..., description="Tenant ID to preview deletion for"),
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Preview what data would be deleted for a tenant (dry-run).

    This endpoint shows counts of all data that would be deleted
    without actually deleting anything. Useful for:
    - Confirming deletion scope before execution
    - Auditing and compliance
    - Troubleshooting

    Returns:
        Dictionary with entity names and their counts
    """
    from app.services.tenant_deletion_service import ForecastingTenantDeletionService

    try:
        logger.info("forecasting.tenant_deletion.preview_called", tenant_id=tenant_id)

        db_manager = create_database_manager(settings.DATABASE_URL, "forecasting")

        async with db_manager.get_session() as session:
            deletion_service = ForecastingTenantDeletionService(session)
            preview = await deletion_service.get_tenant_data_preview(tenant_id)

            total_records = sum(preview.values())

            return {
                "tenant_id": tenant_id,
                "service": "forecasting",
                "preview": preview,
                "total_records": total_records,
                "warning": "These records will be permanently deleted and cannot be recovered"
            }

    except Exception as e:
        logger.error("forecasting.tenant_deletion.preview_error",
                     tenant_id=tenant_id,
                     error=str(e),
                     exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to preview tenant data deletion: {str(e)}"
        )
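
# Intended flow, per the docstrings above: an orchestrator first calls the
# deletion-preview endpoint to audit the scope (record counts per entity),
# then the DELETE endpoint to execute. Both are gated by @service_only_access,
# so they are meant for internal service-to-service calls, not end users.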


@router.get("/health/database")
async def database_health():
    """
    Database health check endpoint with connection pool monitoring.

    Returns detailed connection pool statistics for monitoring and alerting.
    Useful for detecting connection pool exhaustion before it causes issues.
    """
    from app.core.database import get_db_health, get_connection_pool_stats

    try:
        # Check database connectivity
        db_healthy = await get_db_health()

        # Get connection pool statistics
        pool_stats = await get_connection_pool_stats()

        response = {
            "service": "forecasting",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "database_connected": db_healthy,
            "connection_pool": pool_stats,
            "overall_status": "healthy" if db_healthy and pool_stats.get("status") == "healthy" else "degraded"
        }

        # Return a status code that reflects health: 503 when the database is
        # down or the pool is critical, 200 otherwise (including "warning")
        if not db_healthy or pool_stats.get("status") == "critical":
            return JSONResponse(status_code=503, content=response)
        elif pool_stats.get("status") == "warning":
            return JSONResponse(status_code=200, content=response)
        else:
            return response

    except Exception as e:
        logger.error("Health check failed", error=str(e))
        return JSONResponse(
            status_code=503,
            content={
                "service": "forecasting",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "overall_status": "unhealthy",
                "error": str(e)
            }
        )
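
# Example probe (assuming a hypothetical local deployment on port 8000):
#   curl -s http://localhost:8000/health/database
# A 200 with {"overall_status": "healthy"} means the database answered and the
# pool is fine; a 503 means the connection failed or the pool reported
# "critical", which is the condition monitoring should alert on.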