142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
# ================================================================
|
|
# services/forecasting/app/api/predictions.py
|
|
# ================================================================
|
|
"""
|
|
Prediction API endpoints - Real-time prediction capabilities
|
|
"""
|
|
|
|
import structlog
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from typing import List, Dict, Any
|
|
from datetime import date, datetime, timedelta
|
|
|
|
from app.core.database import get_db
|
|
from app.core.auth import get_current_user_from_headers
|
|
from app.services.prediction_service import PredictionService
|
|
from app.schemas.forecasts import ForecastRequest
|
|
|
|
# Router that carries the prediction endpoints defined in this module.
router = APIRouter()

# Structured logger used by the endpoint error handlers.
logger = structlog.get_logger()

# A single PredictionService is created at import time and shared by the
# endpoints in this module.
prediction_service = PredictionService()
|
|
|
|
@router.post("/realtime")
async def get_realtime_prediction(
    product_name: str,
    location: str,
    forecast_date: date,
    features: Dict[str, Any],
    current_user: dict = Depends(get_current_user_from_headers),
):
    """Get real-time prediction without storing in database.

    \f
    Looks up the latest trained model for the caller's tenant, product and
    location, runs it on ``features`` with a 0.8 confidence level and
    returns the result directly; nothing is persisted by this endpoint.

    Args:
        product_name: Product whose demand is predicted.
        location: Location the model must have been trained for.
        forecast_date: Echoed back in the response only. It is NOT merged
            into ``features``; callers must encode any date-derived features
            themselves.
        features: Feature mapping forwarded unchanged to the model.
        current_user: User resolved from the request headers; its
            ``tenant_id`` scopes the model lookup.

    Returns:
        A dict with the predicted demand, its lower/upper confidence bounds,
        the model id and version, a generation timestamp (naive
        ``datetime.now()``), and the features that were used.

    Raises:
        HTTPException: 403 if ``current_user`` carries no ``tenant_id``;
            404 if no trained model exists for the product/location;
            500 for any other failure.
    """
    try:
        raw_tenant_id = current_user.get("tenant_id")
        if raw_tenant_id is None:
            # str(None) would yield the literal tenant id "None" and query a
            # tenant that does not exist; reject the request explicitly.
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User is not associated with a tenant",
            )
        tenant_id = str(raw_tenant_id)

        # Imported lazily as in the original; presumably to avoid an import
        # cycle at module load — TODO confirm before moving it to the top.
        from app.services.forecasting_service import ForecastingService

        # NOTE(review): _get_latest_model is a private ForecastingService
        # method; a public accessor would make this dependency explicit.
        forecasting_service = ForecastingService()
        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        if not model_info:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No trained model found for {product_name}",
            )

        prediction = await prediction_service.predict(
            model_id=model_info["model_id"],
            features=features,
            confidence_level=0.8,
        )

        return {
            "product_name": product_name,
            "location": location,
            "forecast_date": forecast_date,
            "predicted_demand": prediction["demand"],
            "confidence_lower": prediction["lower_bound"],
            "confidence_upper": prediction["upper_bound"],
            "model_id": model_info["model_id"],
            "model_version": model_info["version"],
            "generated_at": datetime.now(),
            "features_used": features,
        }

    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # instead of being turned into a generic 500 below.
        raise
    except Exception as e:
        # exc_info keeps the traceback in the log; the client only receives a
        # generic message so internal details are not exposed.
        logger.error(
            "Error getting realtime prediction", error=str(e), exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
|
|
|
|
@router.get("/quick/{product_name}")
async def get_quick_prediction(
    product_name: str,
    location: str = Query(...),
    days_ahead: int = Query(1, ge=1, le=7),
    current_user: dict = Depends(get_current_user_from_headers),
):
    """Get quick prediction for next few days.

    \f
    Predicts demand for each of the ``days_ahead`` days after today
    (tomorrow first) using the latest trained model for the caller's tenant,
    product and location. Only basic calendar features are given to the
    model.

    Args:
        product_name: Product whose demand is predicted (path parameter).
        location: Location the model must have been trained for.
        days_ahead: Number of consecutive days to predict, 1 to 7.
        current_user: User resolved from the request headers; its
            ``tenant_id`` scopes the model lookup.

    Returns:
        A dict with ``product_name``, ``location``, a ``predictions`` list
        (one entry per day: date, predicted demand, lower/upper confidence
        bounds) and a ``generated_at`` timestamp (naive ``datetime.now()``).
        ``predictions`` is empty when no trained model is found.

    Raises:
        HTTPException: 403 if ``current_user`` carries no ``tenant_id``;
            500 for any other failure.
    """
    try:
        raw_tenant_id = current_user.get("tenant_id")
        if raw_tenant_id is None:
            # str(None) would yield the literal tenant id "None" and query a
            # tenant that does not exist; reject the request explicitly.
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User is not associated with a tenant",
            )
        tenant_id = str(raw_tenant_id)

        # Imported lazily as in the original; presumably to avoid an import
        # cycle at module load — TODO confirm before moving it to the top.
        from app.services.forecasting_service import ForecastingService

        # The model does not depend on the forecast day, so it is looked up
        # once here rather than once per day inside the loop.
        # NOTE(review): _get_latest_model is a private ForecastingService method.
        forecasting_service = ForecastingService()
        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        predictions = []
        # NOTE(review): without a trained model this returns an empty list,
        # whereas /realtime answers 404; kept as-is for client compatibility.
        if model_info:
            # Read the date once so every day is relative to the same "today",
            # even if the loop happens to run across midnight.
            today = date.today()
            for day in range(1, days_ahead + 1):
                forecast_date = today + timedelta(days=day)
                weekday = forecast_date.weekday()  # Monday == 0

                # Basic calendar features only.
                # NOTE(review): business_type is hard-coded; confirm it should
                # not come from the tenant profile.
                features = {
                    "date": forecast_date.isoformat(),
                    "day_of_week": weekday,
                    "is_weekend": weekday >= 5,
                    "business_type": "individual",
                }

                prediction = await prediction_service.predict(
                    model_id=model_info["model_id"],
                    features=features,
                )

                predictions.append({
                    "date": forecast_date,
                    "predicted_demand": prediction["demand"],
                    "confidence_lower": prediction["lower_bound"],
                    "confidence_upper": prediction["upper_bound"],
                })

        return {
            "product_name": product_name,
            "location": location,
            "predictions": predictions,
            "generated_at": datetime.now(),
        }

    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # instead of being turned into a generic 500 below.
        raise
    except Exception as e:
        # exc_info keeps the traceback in the log; the client only receives a
        # generic message so internal details are not exposed.
        logger.error(
            "Error getting quick prediction", error=str(e), exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e
|
|
|