Add forecasting service
This commit is contained in:
141
services/forecasting/app/api/predictions.py
Normal file
141
services/forecasting/app/api/predictions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/api/predictions.py
|
||||
# ================================================================
|
||||
"""
|
||||
Prediction API endpoints - Real-time prediction capabilities
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Dict, Any
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import get_current_user_from_headers
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize service
|
||||
prediction_service = PredictionService()
|
||||
|
||||
@router.post("/realtime")
|
||||
async def get_realtime_prediction(
|
||||
product_name: str,
|
||||
location: str,
|
||||
forecast_date: date,
|
||||
features: Dict[str, Any],
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Get real-time prediction without storing in database"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Get latest model
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
forecasting_service = ForecastingService()
|
||||
|
||||
model_info = await forecasting_service._get_latest_model(
|
||||
tenant_id, product_name, location
|
||||
)
|
||||
|
||||
if not model_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No trained model found for {product_name}"
|
||||
)
|
||||
|
||||
# Generate prediction
|
||||
prediction = await prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
features=features,
|
||||
confidence_level=0.8
|
||||
)
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"location": location,
|
||||
"forecast_date": forecast_date,
|
||||
"predicted_demand": prediction["demand"],
|
||||
"confidence_lower": prediction["lower_bound"],
|
||||
"confidence_upper": prediction["upper_bound"],
|
||||
"model_id": model_info["model_id"],
|
||||
"model_version": model_info["version"],
|
||||
"generated_at": datetime.now(),
|
||||
"features_used": features
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error getting realtime prediction", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.get("/quick/{product_name}")
|
||||
async def get_quick_prediction(
|
||||
product_name: str,
|
||||
location: str = Query(...),
|
||||
days_ahead: int = Query(1, ge=1, le=7),
|
||||
current_user: dict = Depends(get_current_user_from_headers)
|
||||
):
|
||||
"""Get quick prediction for next few days"""
|
||||
|
||||
try:
|
||||
tenant_id = str(current_user.get("tenant_id"))
|
||||
|
||||
# Generate predictions for the next N days
|
||||
predictions = []
|
||||
|
||||
for day in range(1, days_ahead + 1):
|
||||
forecast_date = date.today() + timedelta(days=day)
|
||||
|
||||
# Prepare basic features
|
||||
features = {
|
||||
"date": forecast_date.isoformat(),
|
||||
"day_of_week": forecast_date.weekday(),
|
||||
"is_weekend": forecast_date.weekday() >= 5,
|
||||
"business_type": "individual"
|
||||
}
|
||||
|
||||
# Get model and predict
|
||||
from app.services.forecasting_service import ForecastingService
|
||||
forecasting_service = ForecastingService()
|
||||
|
||||
model_info = await forecasting_service._get_latest_model(
|
||||
tenant_id, product_name, location
|
||||
)
|
||||
|
||||
if model_info:
|
||||
prediction = await prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
features=features
|
||||
)
|
||||
|
||||
predictions.append({
|
||||
"date": forecast_date,
|
||||
"predicted_demand": prediction["demand"],
|
||||
"confidence_lower": prediction["lower_bound"],
|
||||
"confidence_upper": prediction["upper_bound"]
|
||||
})
|
||||
|
||||
return {
|
||||
"product_name": product_name,
|
||||
"location": location,
|
||||
"predictions": predictions,
|
||||
"generated_at": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting quick prediction", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user