2025-07-21 19:48:56 +02:00
|
|
|
# ================================================================
|
|
|
|
|
# services/forecasting/app/api/predictions.py
|
|
|
|
|
# ================================================================
|
|
|
|
|
"""
|
|
|
|
|
Prediction API endpoints - Real-time prediction capabilities
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import structlog
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
from datetime import date, datetime, timedelta
|
2025-08-02 17:09:53 +02:00
|
|
|
from sqlalchemy import select, delete, func
|
|
|
|
|
import uuid
|
2025-07-21 19:48:56 +02:00
|
|
|
|
|
|
|
|
from app.core.database import get_db
|
2025-07-21 20:43:17 +02:00
|
|
|
from shared.auth.decorators import (
|
|
|
|
|
get_current_user_dep,
|
2025-08-02 17:09:53 +02:00
|
|
|
get_current_tenant_id_dep,
|
|
|
|
|
get_current_user_dep,
|
|
|
|
|
require_admin_role
|
2025-07-21 20:43:17 +02:00
|
|
|
)
|
2025-07-21 19:48:56 +02:00
|
|
|
from app.services.prediction_service import PredictionService
|
|
|
|
|
from app.schemas.forecasts import ForecastRequest
|
|
|
|
|
|
|
|
|
|
# Router on which the prediction endpoints of this module are registered.
router = APIRouter()

# Structured logger shared by the endpoints below.
logger = structlog.get_logger()

# Single module-level service instance used to generate predictions.
prediction_service = PredictionService()
|
|
|
|
|
|
|
|
|
|
@router.post("/realtime")
async def get_realtime_prediction(
    product_name: str,
    location: str,
    forecast_date: date,
    features: Dict[str, Any],
    tenant_id: str = Depends(get_current_tenant_id_dep)
):
    """Return a single on-demand prediction; nothing is written to the database here.

    Resolves the latest model for the tenant / product / location, runs the
    prediction service on the caller-supplied ``features`` with a 0.8
    confidence level, and echoes the request fields alongside the result.

    Raises:
        HTTPException: 404 when no model is found, 500 on any other failure.
    """
    try:
        # Imported at call time, as in the rest of this module.
        from app.services.forecasting_service import ForecastingService

        latest_model = await ForecastingService()._get_latest_model(
            tenant_id, product_name, location
        )

        # Guard clause: without a trained model there is nothing to predict with.
        if not latest_model:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No trained model found for {product_name}"
            )

        result = await prediction_service.predict(
            model_id=latest_model["model_id"],
            features=features,
            confidence_level=0.8
        )

        response = {
            "product_name": product_name,
            "location": location,
            "forecast_date": forecast_date,
            "predicted_demand": result["demand"],
            "confidence_lower": result["lower_bound"],
            "confidence_upper": result["upper_bound"],
            "model_id": latest_model["model_id"],
            "model_version": latest_model["version"],
            "generated_at": datetime.now(),
            "features_used": features
        }
        return response

    except HTTPException:
        # Let the deliberate 404 above pass through untouched.
        raise
    except Exception as e:
        logger.error("Error getting realtime prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
|
|
|
|
|
|
|
|
|
|
@router.get("/quick/{product_name}")
async def get_quick_prediction(
    product_name: str,
    location: str = Query(...),
    days_ahead: int = Query(1, ge=1, le=7),
    tenant_id: str = Depends(get_current_tenant_id_dep)
):
    """Return quick predictions for the next ``days_ahead`` days.

    The latest model for the tenant / product / location is resolved once and
    reused for every day. Each day is predicted from a small set of
    calendar-derived features. When no model is found the response still
    succeeds, with an empty ``predictions`` list.

    Args:
        product_name: Product to forecast (path parameter).
        location: Location to forecast for (required query parameter).
        days_ahead: Number of days to forecast, starting tomorrow (1-7).
        tenant_id: Tenant identifier supplied by the auth dependency.

    Raises:
        HTTPException: 500 on any unexpected failure; HTTPExceptions raised
            further down are propagated unchanged.
    """
    try:
        from app.services.forecasting_service import ForecastingService

        # The model does not depend on the forecast day, so look it up once
        # instead of constructing the service and awaiting the lookup per day.
        forecasting_service = ForecastingService()
        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        predictions = []

        if model_info:
            # Anchor on one "today" so a request that straddles midnight
            # cannot produce shifted or duplicated forecast dates.
            start_date = date.today()

            for day in range(1, days_ahead + 1):
                forecast_date = start_date + timedelta(days=day)
                weekday = forecast_date.weekday()

                # Prepare basic features
                features = {
                    "date": forecast_date.isoformat(),
                    "day_of_week": weekday,
                    "is_weekend": weekday >= 5,
                    "business_type": "individual"
                }

                prediction = await prediction_service.predict(
                    model_id=model_info["model_id"],
                    features=features
                )

                predictions.append({
                    "date": forecast_date,
                    "predicted_demand": prediction["demand"],
                    "confidence_lower": prediction["lower_bound"],
                    "confidence_upper": prediction["upper_bound"]
                })

        return {
            "product_name": product_name,
            "location": location,
            "predictions": predictions,
            "generated_at": datetime.now()
        }

    except HTTPException:
        # Same handling as the realtime endpoint: do not mask an explicit
        # HTTP error as a generic 500.
        raise
    except Exception as e:
        logger.error("Error getting quick prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
|
|
|
|
|
|
2025-08-02 17:09:53 +02:00
|
|
|
@router.post("/tenants/{tenant_id}/predictions/cancel-batches")
async def cancel_tenant_prediction_batches(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Mark every queued / running / pending prediction batch of a tenant as cancelled.

    Access is gated by the ``require_admin_role`` dependency. Each matching
    batch gets its status, ``updated_at`` and ``cancelled_by`` set; a failure
    on one batch is recorded in ``errors`` and does not stop the others. The
    session is committed only if at least one batch was cancelled.

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID, 500 when
            the query or commit fails (the session is rolled back first).
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.forecasts import PredictionBatch

        # Select the batches that are still active for this tenant.
        stmt = select(PredictionBatch).where(
            PredictionBatch.tenant_id == tenant_uuid,
            PredictionBatch.status.in_(["queued", "running", "pending"])
        )
        query_result = await db.execute(stmt)
        pending_batches = query_result.scalars().all()

        errors = []
        cancelled_batch_ids = []
        batches_cancelled = 0

        for batch in pending_batches:
            try:
                batch.status = "cancelled"
                batch.updated_at = datetime.utcnow()
                batch.cancelled_by = current_user.get("user_id")
                # Count first, then record the id (order kept deliberately).
                batches_cancelled += 1
                cancelled_batch_ids.append(str(batch.id))

                logger.info(
                    "Cancelled prediction batch",
                    batch_id=str(batch.id),
                    tenant_id=tenant_id,
                )
            except Exception as e:
                # Best effort: note the failure and move on to the next batch.
                error_msg = f"Failed to cancel batch {batch.id}: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)

        # Nothing to persist when no batch was touched successfully.
        if batches_cancelled > 0:
            await db.commit()

        summary = {
            "success": True,
            "tenant_id": tenant_id,
            "batches_cancelled": batches_cancelled,
            "cancelled_batch_ids": cancelled_batch_ids,
            "errors": errors,
            "cancelled_at": datetime.utcnow().isoformat()
        }
        return summary

    except Exception as e:
        await db.rollback()
        logger.error(
            "Failed to cancel tenant prediction batches",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel prediction batches"
        )
|
|
|
|
|
|
|
|
|
|
@router.delete("/tenants/{tenant_id}/predictions/cache")
async def clear_tenant_prediction_cache(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Delete every prediction-cache row belonging to a tenant.

    Access is gated by the ``require_admin_role`` dependency. The rows are
    counted first, then deleted and committed; the response reports both the
    pre-delete count (``expected_count``) and the number of rows the delete
    statement reported (``cache_cleared``).

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID, 500 when
            any database step fails (the session is rolled back first).
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.forecasts import PredictionCache

        tenant_filter = PredictionCache.tenant_id == tenant_uuid

        # How many entries exist before the delete runs.
        count_result = await db.execute(
            select(func.count(PredictionCache.id)).where(tenant_filter)
        )
        cache_count = count_result.scalar()

        # Remove the tenant's entries and persist the change.
        delete_result = await db.execute(
            delete(PredictionCache).where(tenant_filter)
        )
        await db.commit()

        rows_removed = delete_result.rowcount

        logger.info(
            "Cleared tenant prediction cache",
            tenant_id=tenant_id,
            cache_cleared=rows_removed,
        )

        return {
            "success": True,
            "tenant_id": tenant_id,
            "cache_cleared": rows_removed,
            "expected_count": cache_count,
            "cleared_at": datetime.utcnow().isoformat()
        }

    except Exception as e:
        await db.rollback()
        logger.error(
            "Failed to clear tenant prediction cache",
            tenant_id=tenant_id,
            error=str(e),
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to clear prediction cache"
        )
|