# ================================================================
# services/forecasting/app/api/predictions.py
# ================================================================
"""
Prediction API endpoints - Real-time prediction capabilities
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Any
from datetime import date, datetime, timedelta
from sqlalchemy import select, delete, func
import uuid

from app.core.database import get_db
from shared.auth.decorators import (
    get_current_user_dep,
    get_current_tenant_id_dep,
    require_admin_role
)
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest

logger = structlog.get_logger()
router = APIRouter()

# Initialize service
prediction_service = PredictionService()


@router.post("/realtime")
async def get_realtime_prediction(
    product_name: str,
    location: str,
    forecast_date: date,
    features: Dict[str, Any],
    tenant_id: str = Depends(get_current_tenant_id_dep)
):
    """Get real-time prediction without storing in database"""
    try:
        # Get latest model
        from app.services.forecasting_service import ForecastingService
        forecasting_service = ForecastingService()

        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        if not model_info:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No trained model found for {product_name}"
            )

        # Generate prediction
        prediction = await prediction_service.predict(
            model_id=model_info["model_id"],
            features=features,
            confidence_level=0.8
        )

        return {
            "product_name": product_name,
            "location": location,
            "forecast_date": forecast_date,
            "predicted_demand": prediction["demand"],
            "confidence_lower": prediction["lower_bound"],
            "confidence_upper": prediction["upper_bound"],
            "model_id": model_info["model_id"],
            "model_version": model_info["version"],
            "generated_at": datetime.now(),
            "features_used": features
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error getting realtime prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )


@router.get("/quick/{product_name}")
async def get_quick_prediction(
    product_name: str,
    location: str = Query(...),
    days_ahead: int = Query(1, ge=1, le=7),
    tenant_id: str = Depends(get_current_tenant_id_dep)
):
    """Get quick prediction for next few days"""
    try:
        # Resolve the latest model once; it does not depend on the forecast day
        from app.services.forecasting_service import ForecastingService
        forecasting_service = ForecastingService()

        model_info = await forecasting_service._get_latest_model(
            tenant_id, product_name, location
        )

        # Generate predictions for the next N days
        predictions = []

        if model_info:
            for day in range(1, days_ahead + 1):
                forecast_date = date.today() + timedelta(days=day)

                # Prepare basic features
                features = {
                    "date": forecast_date.isoformat(),
                    "day_of_week": forecast_date.weekday(),
                    "is_weekend": forecast_date.weekday() >= 5,
                    "business_type": "individual"
                }

                prediction = await prediction_service.predict(
                    model_id=model_info["model_id"],
                    features=features
                )

                predictions.append({
                    "date": forecast_date,
                    "predicted_demand": prediction["demand"],
                    "confidence_lower": prediction["lower_bound"],
                    "confidence_upper": prediction["upper_bound"]
                })

        return {
            "product_name": product_name,
            "location": location,
            "predictions": predictions,
            "generated_at": datetime.now()
        }

    except Exception as e:
        logger.error("Error getting quick prediction", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
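# ----------------------------------------------------------------
# Illustrative request body for POST /realtime (a sketch, not a schema
# defined by this module): `features` is an arbitrary mapping that is
# passed straight through to PredictionService.predict. A minimal
# payload mirroring the keys built in get_quick_prediction above might
# look like:
#
#   {
#       "date": "2024-06-01",
#       "day_of_week": 5,
#       "is_weekend": true,
#       "business_type": "individual"
#   }
#
# Whether additional keys influence the result depends on how the
# underlying model was trained; they are echoed back in "features_used"
# either way.
# ----------------------------------------------------------------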
quick prediction", error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error" ) @router.post("/tenants/{tenant_id}/predictions/cancel-batches") async def cancel_tenant_prediction_batches( tenant_id: str, current_user = Depends(get_current_user_dep), _admin_check = Depends(require_admin_role), db: AsyncSession = Depends(get_db) ): """Cancel all active prediction batches for a tenant (admin only)""" try: tenant_uuid = uuid.UUID(tenant_id) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tenant ID format" ) try: from app.models.forecasts import PredictionBatch # Find active prediction batches active_batches_query = select(PredictionBatch).where( PredictionBatch.tenant_id == tenant_uuid, PredictionBatch.status.in_(["queued", "running", "pending"]) ) active_batches_result = await db.execute(active_batches_query) active_batches = active_batches_result.scalars().all() batches_cancelled = 0 cancelled_batch_ids = [] errors = [] for batch in active_batches: try: batch.status = "cancelled" batch.updated_at = datetime.utcnow() batch.cancelled_by = current_user.get("user_id") batches_cancelled += 1 cancelled_batch_ids.append(str(batch.id)) logger.info("Cancelled prediction batch", batch_id=str(batch.id), tenant_id=tenant_id) except Exception as e: error_msg = f"Failed to cancel batch {batch.id}: {str(e)}" errors.append(error_msg) logger.error(error_msg) if batches_cancelled > 0: await db.commit() return { "success": True, "tenant_id": tenant_id, "batches_cancelled": batches_cancelled, "cancelled_batch_ids": cancelled_batch_ids, "errors": errors, "cancelled_at": datetime.utcnow().isoformat() } except Exception as e: await db.rollback() logger.error("Failed to cancel tenant prediction batches", tenant_id=tenant_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to cancel prediction batches" ) @router.delete("/tenants/{tenant_id}/predictions/cache") async def clear_tenant_prediction_cache( tenant_id: str, current_user = Depends(get_current_user_dep), _admin_check = Depends(require_admin_role), db: AsyncSession = Depends(get_db) ): """Clear all prediction cache for a tenant (admin only)""" try: tenant_uuid = uuid.UUID(tenant_id) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tenant ID format" ) try: from app.models.forecasts import PredictionCache # Count cache entries before deletion cache_count_query = select(func.count(PredictionCache.id)).where( PredictionCache.tenant_id == tenant_uuid ) cache_count_result = await db.execute(cache_count_query) cache_count = cache_count_result.scalar() # Delete cache entries cache_delete_query = delete(PredictionCache).where( PredictionCache.tenant_id == tenant_uuid ) cache_delete_result = await db.execute(cache_delete_query) await db.commit() logger.info("Cleared tenant prediction cache", tenant_id=tenant_id, cache_cleared=cache_delete_result.rowcount) return { "success": True, "tenant_id": tenant_id, "cache_cleared": cache_delete_result.rowcount, "expected_count": cache_count, "cleared_at": datetime.utcnow().isoformat() } except Exception as e: await db.rollback() logger.error("Failed to clear tenant prediction cache", tenant_id=tenant_id, error=str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to clear prediction cache" )