# ================================================================
# services/forecasting/app/api/forecasts.py
# ================================================================
"""
Forecast API endpoints
"""

import uuid
from datetime import date, datetime
from typing import List, Optional

import structlog
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status
from sqlalchemy import and_, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from shared.auth.decorators import (
    get_current_user_dep,
    require_admin_role
)
from app.services.forecasting_service import ForecastingService
from app.schemas.forecasts import (
    ForecastRequest, ForecastResponse, BatchForecastRequest,
    BatchForecastResponse, AlertResponse
)
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
from app.services.messaging import publish_forecasts_deleted_event

logger = structlog.get_logger()
router = APIRouter()

# Initialize service
forecasting_service = ForecastingService()
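
# A minimal sketch (assumption, not defined in this module) of how the service's
# FastAPI application is expected to mount this router; the actual app module,
# prefix, and tags may differ:
#
#     from fastapi import FastAPI
#     from app.api.forecasts import router as forecasts_router
#
#     app = FastAPI(title="Forecasting Service")
#     app.include_router(forecasts_router, prefix="/api/v1", tags=["forecasts"])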


@router.post("/tenants/{tenant_id}/forecasts/single", response_model=ForecastResponse)
async def create_single_forecast(
    request: ForecastRequest,
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID")
):
    """Generate a single product forecast"""

    try:
        # Generate forecast
        forecast = await forecasting_service.generate_forecast(tenant_id, request, db)

        # Convert to response model
        return ForecastResponse(
            id=str(forecast.id),
            tenant_id=tenant_id,
            product_name=forecast.product_name,
            location=forecast.location,
            forecast_date=forecast.forecast_date,
            predicted_demand=forecast.predicted_demand,
            confidence_lower=forecast.confidence_lower,
            confidence_upper=forecast.confidence_upper,
            confidence_level=forecast.confidence_level,
            model_id=str(forecast.model_id),
            model_version=forecast.model_version,
            algorithm=forecast.algorithm,
            business_type=forecast.business_type,
            is_holiday=forecast.is_holiday,
            is_weekend=forecast.is_weekend,
            day_of_week=forecast.day_of_week,
            weather_temperature=forecast.weather_temperature,
            weather_precipitation=forecast.weather_precipitation,
            weather_description=forecast.weather_description,
            traffic_volume=forecast.traffic_volume,
            created_at=forecast.created_at,
            processing_time_ms=forecast.processing_time_ms,
            features_used=forecast.features_used
        )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error("Error creating single forecast", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
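
# A hedged usage sketch for the single-forecast endpoint above. The JSON body must
# match app.schemas.forecasts.ForecastRequest, whose fields are not shown here, so
# the payload is passed through as an opaque dict; base_url and mounting prefix are
# assumptions:
#
#     import httpx
#
#     async def request_single_forecast(base_url: str, tenant_id: str, payload: dict) -> dict:
#         # POST the ForecastRequest payload and return the ForecastResponse JSON
#         async with httpx.AsyncClient(base_url=base_url) as client:
#             resp = await client.post(f"/tenants/{tenant_id}/forecasts/single", json=payload)
#             resp.raise_for_status()
#             return resp.json()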


@router.post("/tenants/{tenant_id}/forecasts/batch", response_model=BatchForecastResponse)
async def create_batch_forecast(
    request: BatchForecastRequest,
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: dict = Depends(get_current_user_dep)
):
    """Generate batch forecasts for multiple products"""

    try:
        # Verify tenant access
        if str(request.tenant_id) != tenant_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to this tenant"
            )

        # Generate batch forecast
        batch = await forecasting_service.generate_batch_forecast(request, db)

        # Get associated forecasts
        forecasts = await forecasting_service.get_forecasts(
            tenant_id=request.tenant_id,
            location=request.location,
            db=db
        )

        # Convert forecasts to response models
        forecast_responses = []
        for forecast in forecasts[:batch.total_products]:  # Limit to batch size
            forecast_responses.append(ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            ))

        return BatchForecastResponse(
            id=str(batch.id),
            tenant_id=str(batch.tenant_id),
            batch_name=batch.batch_name,
            status=batch.status,
            total_products=batch.total_products,
            completed_products=batch.completed_products,
            failed_products=batch.failed_products,
            requested_at=batch.requested_at,
            completed_at=batch.completed_at,
            processing_time_ms=batch.processing_time_ms,
            forecasts=forecast_responses
        )

    except HTTPException:
        # Re-raise the explicit 403 above instead of masking it as a 500
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e)
        )
    except Exception as e:
        logger.error("Error creating batch forecast", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )


@router.get("/tenants/{tenant_id}/forecasts/list", response_model=List[ForecastResponse])
async def list_forecasts(
    location: str,
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    product_name: Optional[str] = Query(None),
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID")
):
    """List forecasts with filtering"""

    try:
        # Get forecasts
        forecasts = await forecasting_service.get_forecasts(
            tenant_id=tenant_id,
            location=location,
            start_date=start_date,
            end_date=end_date,
            product_name=product_name,
            db=db
        )

        # Convert to response models
        return [
            ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            )
            for forecast in forecasts
        ]

    except Exception as e:
        logger.error("Error listing forecasts", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
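
# Illustrative query shape for the listing endpoint above (example values only; the
# host and prefix depend on how the router is mounted):
#
#     GET /tenants/{tenant_id}/forecasts/list?location=downtown&start_date=2024-01-01&end_date=2024-01-31&product_name=espresso
#
# `location` is required; `start_date`, `end_date`, and `product_name` are optional filters.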


@router.get("/tenants/{tenant_id}/forecasts/alerts", response_model=List[AlertResponse])
async def get_forecast_alerts(
    active_only: bool = Query(True),
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: dict = Depends(get_current_user_dep)
):
    """Get forecast alerts for tenant"""

    try:
        # Build query
        query = select(ForecastAlert).where(
            ForecastAlert.tenant_id == tenant_id
        )

        if active_only:
            query = query.where(ForecastAlert.is_active == True)

        query = query.order_by(ForecastAlert.created_at.desc())

        # Execute query
        result = await db.execute(query)
        alerts = result.scalars().all()

        # Convert to response models
        return [
            AlertResponse(
                id=str(alert.id),
                tenant_id=str(alert.tenant_id),
                forecast_id=str(alert.forecast_id),
                alert_type=alert.alert_type,
                severity=alert.severity,
                message=alert.message,
                is_active=alert.is_active,
                created_at=alert.created_at,
                acknowledged_at=alert.acknowledged_at,
                notification_sent=alert.notification_sent
            )
            for alert in alerts
        ]

    except Exception as e:
        logger.error("Error getting forecast alerts", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )


@router.put("/tenants/{tenant_id}/forecasts/alerts/{alert_id}/acknowledge")
async def acknowledge_alert(
    alert_id: str,
    db: AsyncSession = Depends(get_db),
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: dict = Depends(get_current_user_dep)
):
    """Acknowledge a forecast alert"""

    try:
        # Get alert scoped to the tenant (select and and_ are imported at module level)
        result = await db.execute(
            select(ForecastAlert).where(
                and_(
                    ForecastAlert.id == alert_id,
                    ForecastAlert.tenant_id == tenant_id
                )
            )
        )
        alert = result.scalar_one_or_none()

        if not alert:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Alert not found"
            )

        # Update alert
        alert.acknowledged_at = datetime.now()
        alert.is_active = False

        await db.commit()

        return {"message": "Alert acknowledged successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error acknowledging alert", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
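
# Illustrative request for the acknowledge endpoint above (path prefix depends on how
# the router is mounted); a successful call deactivates the alert and returns
# {"message": "Alert acknowledged successfully"}:
#
#     PUT /tenants/{tenant_id}/forecasts/alerts/{alert_id}/acknowledge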


@router.delete("/forecasts/tenant/{tenant_id}")
async def delete_tenant_forecasts_complete(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """
    Delete all forecasts and predictions for a tenant.

    **WARNING: This operation is irreversible!**

    This endpoint:
    1. Cancels any active prediction batches
    2. Clears the prediction cache
    3. Deletes all forecast records
    4. Deletes prediction batch records
    5. Deletes model performance metrics
    6. Publishes a deletion event

    Used by the admin user-deletion process to clean up all forecasting data.
    """

    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.predictions import ModelPerformanceMetric, PredictionCache

        deletion_stats = {
            "tenant_id": tenant_id,
            "deleted_at": datetime.utcnow().isoformat(),
            "batches_cancelled": 0,
            "forecasts_deleted": 0,
            "prediction_batches_deleted": 0,
            "performance_metrics_deleted": 0,
            "cache_entries_deleted": 0,
            "errors": []
        }

        # Step 1: Cancel active prediction batches
        try:
            active_batches_query = select(PredictionBatch).where(
                PredictionBatch.tenant_id == tenant_uuid,
                PredictionBatch.status.in_(["pending", "processing"])
            )
            active_batches_result = await db.execute(active_batches_query)
            active_batches = active_batches_result.scalars().all()

            for batch in active_batches:
                batch.status = "cancelled"
                batch.completed_at = datetime.utcnow()
                deletion_stats["batches_cancelled"] += 1

            if active_batches:
                await db.commit()
                logger.info("Cancelled active prediction batches",
                            tenant_id=tenant_id,
                            count=len(active_batches))

        except Exception as e:
            error_msg = f"Error cancelling prediction batches: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)

        # Step 2: Delete prediction cache
        try:
            cache_count_query = select(func.count(PredictionCache.id)).where(
                PredictionCache.tenant_id == tenant_uuid
            )
            cache_count_result = await db.execute(cache_count_query)
            cache_count = cache_count_result.scalar()

            cache_delete_query = delete(PredictionCache).where(
                PredictionCache.tenant_id == tenant_uuid
            )
            await db.execute(cache_delete_query)
            deletion_stats["cache_entries_deleted"] = cache_count

            logger.info("Deleted prediction cache entries",
                        tenant_id=tenant_id,
                        count=cache_count)

        except Exception as e:
            error_msg = f"Error deleting prediction cache: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)

        # Step 3: Delete model performance metrics
        try:
            metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
                ModelPerformanceMetric.tenant_id == tenant_uuid
            )
            metrics_count_result = await db.execute(metrics_count_query)
            metrics_count = metrics_count_result.scalar()

            metrics_delete_query = delete(ModelPerformanceMetric).where(
                ModelPerformanceMetric.tenant_id == tenant_uuid
            )
            await db.execute(metrics_delete_query)
            deletion_stats["performance_metrics_deleted"] = metrics_count

            logger.info("Deleted performance metrics",
                        tenant_id=tenant_id,
                        count=metrics_count)

        except Exception as e:
            error_msg = f"Error deleting performance metrics: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)

        # Step 4: Delete prediction batches
        try:
            batches_count_query = select(func.count(PredictionBatch.id)).where(
                PredictionBatch.tenant_id == tenant_uuid
            )
            batches_count_result = await db.execute(batches_count_query)
            batches_count = batches_count_result.scalar()

            batches_delete_query = delete(PredictionBatch).where(
                PredictionBatch.tenant_id == tenant_uuid
            )
            await db.execute(batches_delete_query)
            deletion_stats["prediction_batches_deleted"] = batches_count

            logger.info("Deleted prediction batches",
                        tenant_id=tenant_id,
                        count=batches_count)

        except Exception as e:
            error_msg = f"Error deleting prediction batches: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)

        # Step 5: Delete forecasts (main data)
        try:
            forecasts_count_query = select(func.count(Forecast.id)).where(
                Forecast.tenant_id == tenant_uuid
            )
            forecasts_count_result = await db.execute(forecasts_count_query)
            forecasts_count = forecasts_count_result.scalar()

            forecasts_delete_query = delete(Forecast).where(
                Forecast.tenant_id == tenant_uuid
            )
            await db.execute(forecasts_delete_query)
            deletion_stats["forecasts_deleted"] = forecasts_count

            await db.commit()

            logger.info("Deleted forecasts",
                        tenant_id=tenant_id,
                        count=forecasts_count)

        except Exception as e:
            await db.rollback()
            error_msg = f"Error deleting forecasts: {str(e)}"
            deletion_stats["errors"].append(error_msg)
            logger.error(error_msg)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=error_msg
            )

        # Step 6: Publish deletion event
        try:
            await publish_forecasts_deleted_event(tenant_id, deletion_stats)
        except Exception as e:
            logger.warning("Failed to publish forecasts deletion event", error=str(e))

        return {
            "success": True,
            "message": f"All forecasting data for tenant {tenant_id} deleted successfully",
            "deletion_details": deletion_stats
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Unexpected error deleting tenant forecasts",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete tenant forecasts: {str(e)}"
        )
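
# Illustrative successful response from the tenant-deletion endpoint above; the counts
# are hypothetical, but the keys mirror the deletion_stats dict built in the handler:
#
#     {
#         "success": true,
#         "message": "All forecasting data for tenant <tenant_id> deleted successfully",
#         "deletion_details": {
#             "tenant_id": "<tenant_id>",
#             "deleted_at": "2024-01-01T00:00:00",
#             "batches_cancelled": 1,
#             "forecasts_deleted": 240,
#             "prediction_batches_deleted": 3,
#             "performance_metrics_deleted": 12,
#             "cache_entries_deleted": 18,
#             "errors": []
#         }
#     }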