Add forecasting service
This commit is contained in:
326
services/forecasting/app/api/forecasts.py
Normal file
326
services/forecasting/app/api/forecasts.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/forecasts.py
|
||||
# ================================================================
|
||||
"""
|
||||
Forecast API endpoints
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
from datetime import date
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
from app.schemas.forecasts import (
|
||||
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||
BatchForecastResponse, AlertResponse
|
||||
)
|
||||
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize service
|
||||
forecasting_service = ForecastingService()
|
||||
|
||||
@router.post("/single", response_model=ForecastResponse)
|
||||
async def create_single_forecast(
|
||||
request: ForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Generate a single product forecast"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
)
|
||||
|
||||
# Generate forecast
|
||||
forecast = await forecasting_service.generate_forecast(request, db)
|
||||
|
||||
# Convert to response model
|
||||
return ForecastResponse(
|
||||
id=str(forecast.id),
|
||||
tenant_id=str(forecast.tenant_id),
|
||||
product_name=forecast.product_name,
|
||||
location=forecast.location,
|
||||
forecast_date=forecast.forecast_date,
|
||||
predicted_demand=forecast.predicted_demand,
|
||||
confidence_lower=forecast.confidence_lower,
|
||||
confidence_upper=forecast.confidence_upper,
|
||||
confidence_level=forecast.confidence_level,
|
||||
model_id=str(forecast.model_id),
|
||||
model_version=forecast.model_version,
|
||||
algorithm=forecast.algorithm,
|
||||
business_type=forecast.business_type,
|
||||
is_holiday=forecast.is_holiday,
|
||||
is_weekend=forecast.is_weekend,
|
||||
day_of_week=forecast.day_of_week,
|
||||
weather_temperature=forecast.weather_temperature,
|
||||
weather_precipitation=forecast.weather_precipitation,
|
||||
weather_description=forecast.weather_description,
|
||||
traffic_volume=forecast.traffic_volume,
|
||||
created_at=forecast.created_at,
|
||||
processing_time_ms=forecast.processing_time_ms,
|
||||
features_used=forecast.features_used
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error creating single forecast", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.post("/batch", response_model=BatchForecastResponse)
|
||||
async def create_batch_forecast(
|
||||
request: BatchForecastRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Generate batch forecasts for multiple products"""
|
||||
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(request.tenant_id) != str(current_user.get("tenant_id")):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this tenant"
|
||||
)
|
||||
|
||||
# Generate batch forecast
|
||||
batch = await forecasting_service.generate_batch_forecast(request, db)
|
||||
|
||||
# Get associated forecasts
|
||||
forecasts = await forecasting_service.get_forecasts(
|
||||
tenant_id=request.tenant_id,
|
||||
location=request.location,
|
||||
db=db
|
||||
)
|
||||
|
||||
# Convert forecasts to response models
|
||||
forecast_responses = []
|
||||
for forecast in forecasts[:batch.total_products]: # Limit to batch size
|
||||
forecast_responses.append(ForecastResponse(
|
||||
id=str(forecast.id),
|
||||
tenant_id=str(forecast.tenant_id),
|
||||
product_name=forecast.product_name,
|
||||
location=forecast.location,
|
||||
forecast_date=forecast.forecast_date,
|
||||
predicted_demand=forecast.predicted_demand,
|
||||
confidence_lower=forecast.confidence_lower,
|
||||
confidence_upper=forecast.confidence_upper,
|
||||
confidence_level=forecast.confidence_level,
|
||||
model_id=str(forecast.model_id),
|
||||
model_version=forecast.model_version,
|
||||
algorithm=forecast.algorithm,
|
||||
business_type=forecast.business_type,
|
||||
is_holiday=forecast.is_holiday,
|
||||
is_weekend=forecast.is_weekend,
|
||||
day_of_week=forecast.day_of_week,
|
||||
weather_temperature=forecast.weather_temperature,
|
||||
weather_precipitation=forecast.weather_precipitation,
|
||||
weather_description=forecast.weather_description,
|
||||
traffic_volume=forecast.traffic_volume,
|
||||
created_at=forecast.created_at,
|
||||
processing_time_ms=forecast.processing_time_ms,
|
||||
features_used=forecast.features_used
|
||||
))
|
||||
|
||||
return BatchForecastResponse(
|
||||
id=str(batch.id),
|
||||
tenant_id=str(batch.tenant_id),
|
||||
batch_name=batch.batch_name,
|
||||
status=batch.status,
|
||||
total_products=batch.total_products,
|
||||
completed_products=batch.completed_products,
|
||||
failed_products=batch.failed_products,
|
||||
requested_at=batch.requested_at,
|
||||
completed_at=batch.completed_at,
|
||||
processing_time_ms=batch.processing_time_ms,
|
||||
forecasts=forecast_responses
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error creating batch forecast", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.get("/list", response_model=List[ForecastResponse])
|
||||
async def list_forecasts(
|
||||
location: str,
|
||||
start_date: Optional[date] = Query(None),
|
||||
end_date: Optional[date] = Query(None),
|
||||
product_name: Optional[str] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""List forecasts with filtering"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get forecasts
|
||||
forecasts = await forecasting_service.get_forecasts(
|
||||
tenant_id=tenant_id,
|
||||
location=location,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_name=product_name,
|
||||
db=db
|
||||
)
|
||||
|
||||
# Convert to response models
|
||||
return [
|
||||
ForecastResponse(
|
||||
id=str(forecast.id),
|
||||
tenant_id=str(forecast.tenant_id),
|
||||
product_name=forecast.product_name,
|
||||
location=forecast.location,
|
||||
forecast_date=forecast.forecast_date,
|
||||
predicted_demand=forecast.predicted_demand,
|
||||
confidence_lower=forecast.confidence_lower,
|
||||
confidence_upper=forecast.confidence_upper,
|
||||
confidence_level=forecast.confidence_level,
|
||||
model_id=str(forecast.model_id),
|
||||
model_version=forecast.model_version,
|
||||
algorithm=forecast.algorithm,
|
||||
business_type=forecast.business_type,
|
||||
is_holiday=forecast.is_holiday,
|
||||
is_weekend=forecast.is_weekend,
|
||||
day_of_week=forecast.day_of_week,
|
||||
weather_temperature=forecast.weather_temperature,
|
||||
weather_precipitation=forecast.weather_precipitation,
|
||||
weather_description=forecast.weather_description,
|
||||
traffic_volume=forecast.traffic_volume,
|
||||
created_at=forecast.created_at,
|
||||
processing_time_ms=forecast.processing_time_ms,
|
||||
features_used=forecast.features_used
|
||||
)
|
||||
for forecast in forecasts
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error listing forecasts", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.get("/alerts", response_model=List[AlertResponse])
|
||||
async def get_forecast_alerts(
|
||||
active_only: bool = Query(True),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Get forecast alerts for tenant"""
|
||||
|
||||
try:
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Build query
|
||||
query = select(ForecastAlert).where(
|
||||
ForecastAlert.tenant_id == tenant_id
|
||||
)
|
||||
|
||||
if active_only:
|
||||
query = query.where(ForecastAlert.is_active == True)
|
||||
|
||||
query = query.order_by(ForecastAlert.created_at.desc())
|
||||
|
||||
# Execute query
|
||||
result = await db.execute(query)
|
||||
alerts = result.scalars().all()
|
||||
|
||||
# Convert to response models
|
||||
return [
|
||||
AlertResponse(
|
||||
id=str(alert.id),
|
||||
tenant_id=str(alert.tenant_id),
|
||||
forecast_id=str(alert.forecast_id),
|
||||
alert_type=alert.alert_type,
|
||||
severity=alert.severity,
|
||||
message=alert.message,
|
||||
is_active=alert.is_active,
|
||||
created_at=alert.created_at,
|
||||
acknowledged_at=alert.acknowledged_at,
|
||||
notification_sent=alert.notification_sent
|
||||
)
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting forecast alerts", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.put("/alerts/{alert_id}/acknowledge")
|
||||
async def acknowledge_alert(
|
||||
alert_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Acknowledge a forecast alert"""
|
||||
|
||||
try:
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime
|
||||
|
||||
tenant_id = current_user.get("tenant_id")
|
||||
|
||||
# Get alert
|
||||
result = await db.execute(
|
||||
select(ForecastAlert).where(
|
||||
and_(
|
||||
ForecastAlert.id == alert_id,
|
||||
ForecastAlert.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
)
|
||||
alert = result.scalar_one_or_none()
|
||||
|
||||
if not alert:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Alert not found"
|
||||
)
|
||||
|
||||
# Update alert
|
||||
alert.acknowledged_at = datetime.now()
|
||||
alert.is_active = False
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {"message": "Alert acknowledged successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error acknowledging alert", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
Reference in New Issue
Block a user