1202 lines
56 KiB
Python
1202 lines
56 KiB
Python
"""
|
|
Enhanced Forecasting Service with Repository Pattern
|
|
Main forecasting service that uses the repository pattern for data access
|
|
"""
|
|
|
|
import structlog
|
|
import uuid
|
|
import asyncio
|
|
from typing import Dict, List, Any, Optional
|
|
from datetime import datetime, date, timedelta, timezone
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.ml.predictor import BakeryForecaster
|
|
from app.schemas.forecasts import ForecastRequest, ForecastResponse
|
|
from app.services.prediction_service import PredictionService
|
|
from app.services.model_client import ModelClient
|
|
from app.services.data_client import DataClient
|
|
from app.utils.distributed_lock import get_forecast_lock, get_batch_forecast_lock, LockAcquisitionError
|
|
|
|
# Import repositories
|
|
from app.repositories import (
|
|
ForecastRepository,
|
|
PredictionBatchRepository,
|
|
PerformanceMetricRepository,
|
|
PredictionCacheRepository
|
|
)
|
|
|
|
# Import shared database components
|
|
from shared.database.base import create_database_manager
|
|
from shared.database.unit_of_work import UnitOfWork
|
|
from shared.database.transactions import transactional
|
|
from shared.database.exceptions import DatabaseError
|
|
from app.core.config import settings
|
|
|
|
# Module-level structlog logger used by the service methods below.
logger = structlog.get_logger()
|
|
|
|
|
|
class EnhancedForecastingService:
|
|
"""
|
|
Enhanced forecasting service using repository pattern.
|
|
Handles forecast generation, batch processing with proper data abstraction.
|
|
"""
|
|
|
|
def __init__(self, database_manager=None):
    """Wire up the database manager and the ML / data collaborators.

    Args:
        database_manager: optional pre-built database manager. When it is
            falsy, a new one is created from ``settings.DATABASE_URL`` under
            the name ``"forecasting-service"``.
    """
    if not database_manager:
        database_manager = create_database_manager(
            settings.DATABASE_URL, "forecasting-service"
        )
    self.database_manager = database_manager

    # The ML-side collaborators all share the same database manager.
    self.forecaster = BakeryForecaster(database_manager=database_manager)
    self.prediction_service = PredictionService(database_manager=database_manager)
    self.model_client = ModelClient(database_manager=database_manager)

    # DataClient is constructed without a database manager.
    self.data_client = DataClient()
|
|
|
|
async def _init_repositories(self, session):
    """Build the bundle of repositories bound to ``session``.

    Returns:
        dict mapping the short names ``'forecast'``, ``'batch'``,
        ``'performance'`` and ``'cache'`` to repository instances that all
        share the given session.
    """
    repository_classes = (
        ('forecast', ForecastRepository),
        ('batch', PredictionBatchRepository),
        ('performance', PerformanceMetricRepository),
        ('cache', PredictionCacheRepository),
    )
    return {key: repo_cls(session) for key, repo_cls in repository_classes}
|
|
|
|
async def generate_batch_forecasts(self, tenant_id: str, request) -> Dict[str, Any]:
    """Return a placeholder batch-forecast result for ``tenant_id``.

    Batch generation is not implemented yet. No forecasts are produced, and
    ``request`` is currently unused.

    Returns:
        dict with a fresh UUID ``id`` (for database references) and a
        human-readable ``batch_id`` of the form
        ``batch_<tenant_id>_<YYYYmmdd_HHMMSS>``, plus empty/zero counters.
        The timestamp is taken in UTC so the identifier does not depend on
        the host's local timezone. This matches the UTC timestamps used
        elsewhere in this service; the previous code used naive local time.

    Raises:
        Any unexpected error, after logging it.
    """
    try:
        batch_uuid = uuid.uuid4()
        stamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        return {
            "id": str(batch_uuid),  # UUID for database references
            "batch_id": f"batch_{tenant_id}_{stamp}",  # Human-readable batch identifier
            "tenant_id": tenant_id,
            "forecasts": [],
            "total_forecasts": 0,
            "successful_forecasts": 0,
            "failed_forecasts": 0,
            "enhanced_features": True,
            "repository_integration": True
        }
    except Exception as e:
        logger.error("Batch forecast generation failed", error=str(e))
        raise
|
|
|
|
async def get_tenant_forecasts(self, tenant_id: str, inventory_product_id: str = None,
                               start_date: date = None, end_date: date = None,
                               skip: int = 0, limit: int = 100) -> List[Dict]:
    """Return one page of a tenant's forecasts as JSON-friendly dicts.

    Args:
        tenant_id: tenant whose forecasts are listed.
        inventory_product_id: optional product filter.
        start_date / end_date: when BOTH are given, the repository's
            date-range query is used. A single bound on its own is ignored,
            and the unbounded listing (newest ``forecast_date`` first) is
            returned instead.
        skip / limit: pagination window, applied on both query paths. The
            date-range query takes no paging arguments, so its result is
            sliced in memory. Previously ``skip``/``limit`` were silently
            ignored on that path.

    Returns:
        List of dicts. Ids are strings, dates are ISO strings, and
        demand/confidence columns are converted to float.

    Raises:
        Any repository/database error, after logging it.
    """
    try:
        async with self.database_manager.get_background_session() as session:
            repos = await self._init_repositories(session)

            filters = {"tenant_id": tenant_id}
            if inventory_product_id:
                filters["inventory_product_id"] = inventory_product_id

            if start_date and end_date:
                forecasts = await repos['forecast'].get_forecasts_by_date_range(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    inventory_product_id=inventory_product_id
                )
                # The date-range query has no paging arguments: page in memory.
                forecasts = list(forecasts)[skip:skip + limit]
            else:
                forecasts = await repos['forecast'].get_multi(
                    filters=filters,
                    skip=skip,
                    limit=limit,
                    order_by="forecast_date",
                    order_desc=True
                )

            # Serialize while the session is still open.
            forecast_list = [
                {
                    "id": str(forecast.id),
                    "tenant_id": str(forecast.tenant_id),
                    # Stringified for consistency with get_forecast_by_id / _forecast_to_dict.
                    "inventory_product_id": str(forecast.inventory_product_id),
                    "location": forecast.location,
                    "forecast_date": forecast.forecast_date.isoformat(),
                    "predicted_demand": float(forecast.predicted_demand),
                    "confidence_lower": float(forecast.confidence_lower),
                    "confidence_upper": float(forecast.confidence_upper),
                    "confidence_level": float(forecast.confidence_level),
                    "model_id": forecast.model_id,
                    "model_version": forecast.model_version,
                    "algorithm": forecast.algorithm,
                    "business_type": forecast.business_type,
                    "is_holiday": forecast.is_holiday,
                    "is_weekend": forecast.is_weekend,
                    "processing_time_ms": forecast.processing_time_ms,
                    "created_at": forecast.created_at.isoformat() if forecast.created_at else None
                }
                for forecast in forecasts
            ]

            logger.info("Retrieved tenant forecasts",
                        tenant_id=tenant_id,
                        count=len(forecast_list),
                        filters=filters)

            return forecast_list

    except Exception as e:
        logger.error("Failed to get tenant forecasts",
                     tenant_id=tenant_id,
                     error=str(e))
        raise
|
|
|
|
async def list_forecasts(self, tenant_id: str, inventory_product_id: str = None,
                         start_date: date = None, end_date: date = None,
                         limit: int = 100, offset: int = 0) -> List[Dict]:
    """API-compatible alias of ``get_tenant_forecasts``.

    ``offset`` is forwarded as ``skip``; every other argument is passed
    through unchanged, by keyword.
    """
    query = dict(
        tenant_id=tenant_id,
        inventory_product_id=inventory_product_id,
        start_date=start_date,
        end_date=end_date,
        skip=offset,
        limit=limit,
    )
    return await self.get_tenant_forecasts(**query)
|
|
|
|
async def get_forecast_by_id(self, forecast_id: str) -> Optional[Dict]:
    """Load one forecast by primary key and serialize it.

    No tenant check is performed here.

    Returns:
        dict of the forecast's core fields (ids stringified, date as ISO
        string, numeric fields as float), or None when no row is found.

    Raises:
        Any repository/database error, after logging it.
    """
    try:
        async with self.database_manager.get_background_session() as session:
            forecast_repo = (await self._init_repositories(session))['forecast']
            row = await forecast_repo.get(forecast_id)

            if not row:
                return None

            return {
                "id": str(row.id),
                "tenant_id": str(row.tenant_id),
                "inventory_product_id": str(row.inventory_product_id),
                "location": row.location,
                "forecast_date": row.forecast_date.isoformat(),
                "predicted_demand": float(row.predicted_demand),
                "confidence_lower": float(row.confidence_lower),
                "confidence_upper": float(row.confidence_upper),
                "confidence_level": float(row.confidence_level),
                "model_id": row.model_id,
                "model_version": row.model_version,
                "algorithm": row.algorithm
            }
    except Exception as e:
        logger.error("Failed to get forecast by ID", error=str(e))
        raise
|
|
|
|
async def get_forecast(self, tenant_id: str, forecast_id: uuid.UUID) -> Optional[Dict]:
    """Fetch a forecast by id, returning it only if it belongs to ``tenant_id``.

    ``get_forecast_by_id`` returns ``tenant_id`` already stringified, so the
    caller's ``tenant_id`` is normalized with ``str()`` before comparing.
    Callers may therefore pass either a ``str`` or a ``uuid.UUID``;
    previously a UUID argument never matched.

    Returns:
        The serialized forecast, or None when it does not exist or belongs
        to another tenant.
    """
    forecast = await self.get_forecast_by_id(str(forecast_id))
    if not forecast or forecast["tenant_id"] != str(tenant_id):
        return None
    return forecast
|
|
|
|
async def delete_forecast(self, tenant_id: str, forecast_id: uuid.UUID) -> bool:
    """Delete a forecast if it belongs to ``tenant_id`` (best effort).

    The stored ``tenant_id`` and the caller's ``tenant_id`` are both compared
    as strings, so either a ``str`` or a ``uuid.UUID`` may be passed.

    Returns:
        True when the forecast was deleted and committed. False when it does
        not exist, belongs to another tenant, or any error occurs. Errors
        are logged rather than raised.
    """
    tenant_key = str(tenant_id)
    forecast_key = str(forecast_id)
    try:
        async with self.database_manager.get_background_session() as session:
            repos = await self._init_repositories(session)

            # Verify ownership before deleting.
            forecast = await repos['forecast'].get(forecast_key)
            if not forecast or str(forecast.tenant_id) != tenant_key:
                return False

            await repos['forecast'].delete(forecast_key)
            await session.commit()

        # Logged only after the session context has exited without error.
        logger.info("Forecast deleted", tenant_id=tenant_id, forecast_id=forecast_key)
        return True

    except Exception as e:
        logger.error("Failed to delete forecast", error=str(e), tenant_id=tenant_id,
                     forecast_id=forecast_key)
        return False
|
|
|
|
|
|
async def get_tenant_forecast_statistics(self, tenant_id: str) -> Dict[str, Any]:
    """Placeholder forecast statistics for ``tenant_id``.

    Not implemented yet: the counters are always zero and
    ``accuracy_metrics`` is empty. On an error, ``{"error": <message>}`` is
    returned instead of raising.
    """
    try:
        stats: Dict[str, Any] = dict.fromkeys(
            ("total_forecasts", "active_forecasts", "recent_forecasts"), 0
        )
        stats["accuracy_metrics"] = {}
        stats["enhanced_features"] = True
        return stats
    except Exception as e:
        logger.error("Failed to get forecast statistics", error=str(e))
        return {"error": str(e)}
|
|
|
|
async def generate_batch_predictions(self, tenant_id: str, batch_request: Dict) -> Dict[str, Any]:
    """Return a placeholder batch-prediction result for ``tenant_id``.

    Not implemented yet: ``batch_request`` is unused and no predictions are
    made.

    Returns:
        dict with ``batch_id`` of the form
        ``pred_batch_<tenant_id>_<YYYYmmdd_HHMMSS>`` and empty/zero
        counters. The timestamp is UTC, consistent with the rest of this
        service; it was previously naive host-local time.

    Raises:
        Any unexpected error, after logging it.
    """
    try:
        stamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        return {
            "batch_id": f"pred_batch_{tenant_id}_{stamp}",
            "tenant_id": tenant_id,
            "predictions": [],
            "total_predictions": 0,
            "successful_predictions": 0,
            "failed_predictions": 0,
            "enhanced_features": True
        }
    except Exception as e:
        logger.error("Batch predictions failed", error=str(e))
        raise
|
|
|
|
async def get_cached_predictions(self, tenant_id: str, inventory_product_id: str = None,
                                 skip: int = 0, limit: int = 100) -> List[Dict]:
    """Placeholder listing of cached predictions.

    Not implemented yet: the arguments are ignored and an empty list is
    always returned.
    """
    try:
        cached: List[Dict] = []
        return cached
    except Exception as e:
        logger.error("Failed to get cached predictions", error=str(e))
        raise
|
|
|
|
async def clear_prediction_cache(self, tenant_id: str, inventory_product_id: str = None) -> int:
    """Placeholder for clearing cached predictions.

    Not implemented yet: nothing is removed, and the number of cleared
    entries is always reported as 0. An error would also be reported as 0
    after logging.
    """
    cleared = 0
    try:
        # Repository-backed clearing is not wired in yet.
        return cleared
    except Exception as e:
        logger.error("Failed to clear prediction cache", error=str(e))
        return 0
|
|
|
|
async def get_prediction_performance(self, tenant_id: str, model_id: str = None,
                                     start_date: date = None, end_date: date = None) -> Dict[str, Any]:
    """Placeholder prediction-performance report.

    Not implemented yet: the filters are ignored, and the result always
    contains empty ``accuracy_metrics`` and ``performance_trends``.
    """
    try:
        return dict(
            accuracy_metrics={},
            performance_trends=[],
            enhanced_features=True,
        )
    except Exception as e:
        logger.error("Failed to get prediction performance", error=str(e))
        raise
|
|
|
|
async def generate_forecast(
    self,
    tenant_id: str,
    request: ForecastRequest
) -> ForecastResponse:
    """
    Generate (or reuse) a single-day forecast for one product.

    Flow:
    1. Fetch model metadata and prepare features BEFORE opening a database
       session. Feature preparation includes external calls (e.g. weather),
       so no pooled connection is held during that I/O.
    2. Open a session and take a distributed lock keyed on tenant + product
       + calendar date, so concurrent requests for the same forecast do not
       duplicate work.
    3. Inside the lock: return the cached prediction if one exists.
       Otherwise predict, apply business rules, persist the forecast, cache
       it, and commit.

    The forecast row and its cache entry are committed together while still
    inside the ``lock.acquire`` block. A request that was waiting on the lock
    therefore finds the committed cache entry instead of regenerating the
    forecast. Previously the cache entry was written after the commit.

    Raises:
        ValueError: no valid model is available, or the lock could not be
            acquired and no cached result appeared after a short wait.
        Exception: any other failure is logged and re-raised.
    """
    start_time = datetime.now(timezone.utc)

    try:
        # forecast_date is normally a date/datetime; guard against a plain
        # string so that logging alone cannot raise AttributeError.
        forecast_date_label = (
            request.forecast_date.isoformat()
            if hasattr(request.forecast_date, 'isoformat')
            else str(request.forecast_date)
        )

        logger.info("Generating enhanced forecast",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    date=forecast_date_label)

        # Get the model BEFORE opening a database session so no connection is
        # held while potentially slow external calls run.
        logger.debug("Fetching model data before opening database session",
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id)

        model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

        if not model_data:
            raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

        logger.debug("Model data fetched successfully",
                     tenant_id=tenant_id,
                     model_id=model_data.get('model_id'))

        # Feature preparation includes external API calls (weather).
        features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

        # Open the session only after the external calls are complete.
        async with self.database_manager.get_background_session() as session:
            # Lock key: tenant + product + calendar date (time part stripped).
            forecast_date_str = forecast_date_label.split('T')[0]
            lock = get_forecast_lock(
                tenant_id=tenant_id,
                product_id=str(request.inventory_product_id),
                forecast_date=forecast_date_str
            )

            try:
                async with lock.acquire(session):
                    repos = await self._init_repositories(session)

                    # Check the cache inside the lock: a concurrent request may
                    # have produced this forecast while we waited.
                    cached_prediction = await repos['cache'].get_cached_prediction(
                        tenant_id, request.inventory_product_id, request.location, request.forecast_date
                    )

                    if cached_prediction:
                        logger.info("Found cached prediction after acquiring lock (concurrent request completed first)",
                                    tenant_id=tenant_id,
                                    inventory_product_id=request.inventory_product_id)
                        return self._create_forecast_response_from_cache(cached_prediction)

                    prediction_result = await self.prediction_service.predict(
                        model_id=model_data['model_id'],
                        model_path=model_data['model_path'],
                        features=features,
                        confidence_level=request.confidence_level
                    )

                    adjusted_prediction = self._apply_business_rules(
                        prediction_result, request, features
                    )

                    predicted_demand = adjusted_prediction['prediction']
                    # Fall back to a -20% / +20% band when no interval is returned.
                    confidence_lower = adjusted_prediction.get('lower_bound', predicted_demand * 0.8)
                    confidence_upper = adjusted_prediction.get('upper_bound', predicted_demand * 1.2)

                    # Convert a string forecast_date to datetime for storage.
                    forecast_datetime = request.forecast_date
                    if isinstance(forecast_datetime, str):
                        from dateutil.parser import parse
                        forecast_datetime = parse(forecast_datetime)

                    forecast_data = {
                        "tenant_id": tenant_id,
                        "inventory_product_id": request.inventory_product_id,
                        "product_name": None,  # Field is now nullable, use inventory_product_id as reference
                        "location": request.location,
                        "forecast_date": forecast_datetime,
                        "predicted_demand": predicted_demand,
                        "confidence_lower": confidence_lower,
                        "confidence_upper": confidence_upper,
                        "confidence_level": request.confidence_level,
                        "model_id": model_data['model_id'],
                        "model_version": str(model_data.get('version', '1.0')),
                        "algorithm": model_data.get('algorithm', 'prophet'),
                        "business_type": features.get('business_type', 'individual'),
                        "is_holiday": features.get('is_holiday', False),
                        "is_weekend": features.get('is_weekend', False),
                        "day_of_week": features.get('day_of_week', 0),
                        "weather_temperature": features.get('temperature'),
                        "weather_precipitation": features.get('precipitation'),
                        "weather_description": features.get('weather_description'),
                        "traffic_volume": features.get('traffic_volume'),
                        "processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
                        "features_used": features
                    }

                    forecast = await repos['forecast'].create_forecast(forecast_data)

                    await repos['cache'].cache_prediction(
                        tenant_id=tenant_id,
                        inventory_product_id=request.inventory_product_id,
                        location=request.location,
                        forecast_date=forecast_datetime,
                        predicted_demand=predicted_demand,
                        confidence_lower=confidence_lower,
                        confidence_upper=confidence_upper,
                        model_id=model_data['model_id'],
                        expires_in_hours=24
                    )

                    # Commit forecast + cache entry together, still inside the
                    # lock, so a waiting request sees the committed cache entry.
                    await session.commit()

                    logger.info("Enhanced forecast generated successfully",
                                forecast_id=forecast.id,
                                tenant_id=tenant_id,
                                prediction=predicted_demand)

                    return self._create_forecast_response_from_model(forecast)

            except LockAcquisitionError:
                # Another request holds the lock for this forecast.
                logger.warning("Could not acquire forecast lock, checking cache for concurrent request result",
                               tenant_id=tenant_id,
                               inventory_product_id=request.inventory_product_id,
                               forecast_date=forecast_date_str)

                # Give the concurrent request a moment, then look for its result.
                await asyncio.sleep(1)

                repos = await self._init_repositories(session)
                cached_prediction = await repos['cache'].get_cached_prediction(
                    tenant_id, request.inventory_product_id, request.location, request.forecast_date
                )

                if cached_prediction:
                    logger.info("Found forecast in cache after lock timeout (concurrent request completed)",
                                tenant_id=tenant_id,
                                inventory_product_id=request.inventory_product_id)
                    return self._create_forecast_response_from_cache(cached_prediction)

                raise ValueError(
                    f"Forecast generation already in progress for product {request.inventory_product_id}. "
                    "Please try again in a few seconds."
                )

    except Exception as e:
        processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
        logger.error("Error generating enhanced forecast",
                     error=str(e),
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id,
                     processing_time=processing_time)
        raise
|
|
|
|
async def generate_multi_day_forecast(
    self,
    tenant_id: str,
    request: ForecastRequest,
    latitude: float = 40.4168,
    longitude: float = -3.7038
) -> Dict[str, Any]:
    """
    Generate one forecast per day for ``request.forecast_days`` days.

    The period starts at ``request.forecast_date``. Weather is fetched once
    for the whole period and shared by every day through a date-keyed map,
    which avoids one weather API call per day.

    Args:
        tenant_id: tenant owning the forecasts.
        request: template request. ``forecast_date`` is the first day and
            ``forecast_days`` the number of days.
        latitude / longitude: location used for the weather forecast. They
            default to the previously hard-coded Madrid coordinates, so
            existing callers are unaffected.

    Returns:
        dict with the per-day forecasts serialized as dicts, the total
        predicted demand, the average confidence level and the processing
        time. ``average_confidence_level`` is 0.0 when no days were generated;
        previously this case raised ZeroDivisionError.

    Raises:
        Any failure, after logging it.
    """
    start_time = datetime.now(timezone.utc)
    forecasts = []

    try:
        # Guard against a plain-string forecast_date, which has no isoformat().
        start_label = (
            request.forecast_date.isoformat()
            if hasattr(request.forecast_date, 'isoformat')
            else str(request.forecast_date)
        )

        logger.info("Generating multi-day forecast",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    forecast_days=request.forecast_days,
                    start_date=start_label)

        # Fetch the weather forecast ONCE for all days to reduce API calls.
        weather_forecasts = await self.data_client.fetch_weather_forecast(
            tenant_id=tenant_id,
            days=request.forecast_days,
            latitude=latitude,
            longitude=longitude
        )

        # Map "YYYY-MM-DD" -> weather entry. A None response is treated as
        # "no weather data" instead of crashing the loop.
        weather_map = {}
        for weather in weather_forecasts or []:
            weather_date = weather.get('forecast_date', '')
            if isinstance(weather_date, str):
                weather_date = weather_date.split('T')[0]
            elif hasattr(weather_date, 'date'):
                weather_date = weather_date.date().isoformat()
            else:
                weather_date = str(weather_date).split('T')[0]
            weather_map[weather_date] = weather

        # Resolve the first day once rather than re-parsing it on every iteration.
        first_date = request.forecast_date
        if isinstance(first_date, str):
            from dateutil.parser import parse
            first_date = parse(first_date).date()

        for day_offset in range(request.forecast_days):
            current_date = first_date if day_offset == 0 else first_date + timedelta(days=day_offset)

            daily_request = ForecastRequest(
                inventory_product_id=request.inventory_product_id,
                forecast_date=current_date,
                forecast_days=1,  # Single day for each iteration
                location=request.location,
                confidence_level=request.confidence_level
            )

            daily_forecast = await self.generate_forecast_with_weather_map(tenant_id, daily_request, weather_map)
            forecasts.append(daily_forecast)

        total_demand = sum(f.predicted_demand for f in forecasts)
        avg_confidence = (
            sum(f.confidence_level for f in forecasts) / len(forecasts)
            if forecasts else 0.0
        )
        processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

        forecast_dicts = [
            {
                "id": forecast.id,
                "tenant_id": forecast.tenant_id,
                "inventory_product_id": forecast.inventory_product_id,
                "location": forecast.location,
                "forecast_date": forecast.forecast_date.isoformat() if hasattr(forecast.forecast_date, 'isoformat') else str(forecast.forecast_date),
                "predicted_demand": forecast.predicted_demand,
                "confidence_lower": forecast.confidence_lower,
                "confidence_upper": forecast.confidence_upper,
                "confidence_level": forecast.confidence_level,
                "model_id": forecast.model_id,
                "model_version": forecast.model_version,
                "algorithm": forecast.algorithm,
                "business_type": forecast.business_type,
                "is_holiday": forecast.is_holiday,
                "is_weekend": forecast.is_weekend,
                "day_of_week": forecast.day_of_week,
                "weather_temperature": forecast.weather_temperature,
                "weather_precipitation": forecast.weather_precipitation,
                "weather_description": forecast.weather_description,
                "traffic_volume": forecast.traffic_volume,
                "created_at": forecast.created_at.isoformat() if hasattr(forecast.created_at, 'isoformat') else str(forecast.created_at),
                "processing_time_ms": forecast.processing_time_ms,
                "features_used": forecast.features_used
            }
            for forecast in forecasts
        ]

        return {
            "tenant_id": tenant_id,
            "inventory_product_id": request.inventory_product_id,
            "forecast_start_date": start_label,
            "forecast_days": request.forecast_days,
            "forecasts": forecast_dicts,
            "total_predicted_demand": total_demand,
            "average_confidence_level": avg_confidence,
            "processing_time_ms": processing_time
        }

    except Exception as e:
        logger.error("Multi-day forecast generation failed",
                     tenant_id=tenant_id,
                     error=str(e))
        raise
|
|
|
|
async def generate_forecast_with_weather_map(
    self,
    tenant_id: str,
    request: ForecastRequest,
    weather_map: Dict[str, Any]
) -> ForecastResponse:
    """
    Generate a single-day forecast using a pre-fetched weather map.

    Used by multi-day generation so that weather is fetched once for the
    whole period instead of once per day.

    Args:
        tenant_id: tenant owning the forecast.
        request: single-day forecast request.
        weather_map: weather entries keyed by date string, as built by
            ``generate_multi_day_forecast``. It is handed to the
            feature-preparation helper.

    Returns:
        A response built from the cached prediction when one exists,
        otherwise from the newly stored forecast.

    Raises:
        ValueError: no valid model is available.
        Exception: any other failure is logged and re-raised.

    Unlike ``generate_forecast``, no distributed lock is taken here. The
    forecast row and its cache entry are now committed explicitly,
    consistent with ``generate_forecast``. Previously this method never
    called ``session.commit()``.
    """
    start_time = datetime.now(timezone.utc)

    try:
        # Guard against a plain-string forecast_date, which has no isoformat().
        forecast_date_label = (
            request.forecast_date.isoformat()
            if hasattr(request.forecast_date, 'isoformat')
            else str(request.forecast_date)
        )

        logger.info("Generating enhanced forecast with weather map",
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    date=forecast_date_label)

        # Get the model BEFORE opening a database session so no connection is
        # held while potentially slow external calls run.
        logger.debug("Fetching model data before opening database session",
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id)

        model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

        if not model_data:
            raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

        logger.debug("Model data fetched successfully",
                     tenant_id=tenant_id,
                     model_id=model_data.get('model_id'))

        # Weather comes from weather_map, so presumably no weather HTTP call
        # is made here (depends on the helper; not visible in this block).
        features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)

        # Open the session only after feature preparation is complete.
        async with self.database_manager.get_background_session() as session:
            repos = await self._init_repositories(session)

            cached_prediction = await repos['cache'].get_cached_prediction(
                tenant_id, request.inventory_product_id, request.location, request.forecast_date
            )

            if cached_prediction:
                logger.debug("Using cached prediction",
                             tenant_id=tenant_id,
                             inventory_product_id=request.inventory_product_id)
                return self._create_forecast_response_from_cache(cached_prediction)

            prediction_result = await self.prediction_service.predict(
                model_id=model_data['model_id'],
                model_path=model_data['model_path'],
                features=features,
                confidence_level=request.confidence_level
            )

            adjusted_prediction = self._apply_business_rules(
                prediction_result, request, features
            )

            predicted_demand = adjusted_prediction['prediction']
            # Fall back to a -20% / +20% band when no interval is returned.
            confidence_lower = adjusted_prediction.get('lower_bound', predicted_demand * 0.8)
            confidence_upper = adjusted_prediction.get('upper_bound', predicted_demand * 1.2)

            # Convert a string forecast_date to datetime for storage.
            forecast_datetime = request.forecast_date
            if isinstance(forecast_datetime, str):
                from dateutil.parser import parse
                forecast_datetime = parse(forecast_datetime)

            forecast_data = {
                "tenant_id": tenant_id,
                "inventory_product_id": request.inventory_product_id,
                "product_name": None,  # Field is now nullable, use inventory_product_id as reference
                "location": request.location,
                "forecast_date": forecast_datetime,
                "predicted_demand": predicted_demand,
                "confidence_lower": confidence_lower,
                "confidence_upper": confidence_upper,
                "confidence_level": request.confidence_level,
                "model_id": model_data['model_id'],
                "model_version": str(model_data.get('version', '1.0')),
                "algorithm": model_data.get('algorithm', 'prophet'),
                "business_type": features.get('business_type', 'individual'),
                "is_holiday": features.get('is_holiday', False),
                "is_weekend": features.get('is_weekend', False),
                "day_of_week": features.get('day_of_week', 0),
                "weather_temperature": features.get('temperature'),
                "weather_precipitation": features.get('precipitation'),
                "weather_description": features.get('weather_description'),
                "traffic_volume": features.get('traffic_volume'),
                "processing_time_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
                "features_used": features
            }

            forecast = await repos['forecast'].create_forecast(forecast_data)

            await repos['cache'].cache_prediction(
                tenant_id=tenant_id,
                inventory_product_id=request.inventory_product_id,
                location=request.location,
                forecast_date=forecast_datetime,
                predicted_demand=predicted_demand,
                confidence_lower=confidence_lower,
                confidence_upper=confidence_upper,
                model_id=model_data['model_id'],
                expires_in_hours=24
            )

            # Persist forecast + cache entry (same as generate_forecast).
            await session.commit()

            logger.info("Enhanced forecast generated successfully",
                        forecast_id=forecast.id,
                        tenant_id=tenant_id,
                        prediction=predicted_demand)

            return self._create_forecast_response_from_model(forecast)

    except Exception as e:
        processing_time = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
        logger.error("Error generating enhanced forecast",
                     error=str(e),
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id,
                     processing_time=processing_time)
        raise
|
|
|
|
async def get_forecast_history(
    self,
    tenant_id: str,
    inventory_product_id: Optional[str] = None,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None
) -> List[Dict[str, Any]]:
    """Return a tenant's forecast history as dicts (best effort).

    When both ``start_date`` and ``end_date`` are given, the repository's
    date-range query is used and receives ``inventory_product_id``.
    Otherwise the last 30 days (24 * 30 hours) of records are fetched.
    ``get_recent_records`` takes no product argument, so the product filter
    is applied here. Previously it was silently dropped on that path.

    Returns:
        List of dicts from ``_forecast_to_dict``, or an empty list if
        anything fails (the error is logged).
    """
    try:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)

            if start_date and end_date:
                forecasts = await repos['forecast'].get_forecasts_by_date_range(
                    tenant_id, start_date, end_date, inventory_product_id
                )
            else:
                forecasts = await repos['forecast'].get_recent_records(
                    tenant_id, hours=24*30
                )
                if inventory_product_id:
                    wanted = str(inventory_product_id)
                    forecasts = [
                        f for f in forecasts
                        if str(f.inventory_product_id) == wanted
                    ]

            return [self._forecast_to_dict(forecast) for forecast in forecasts]

    except Exception as e:
        logger.error("Failed to get forecast history",
                     tenant_id=tenant_id,
                     error=str(e))
        return []
|
|
|
|
async def get_forecast_analytics(self, tenant_id: str) -> Dict[str, Any]:
    """Collect forecast, batch, cache and performance analytics for a tenant.

    The four repository queries are awaited one after another on the same
    session. Performance trends cover the last 30 days.

    Returns:
        dict with ``tenant_id``, ``forecast_analytics``, ``batch_analytics``,
        ``cache_performance``, ``performance_trends`` and a UTC
        ``generated_at`` timestamp. On any failure, returns
        ``{"error": "Failed to get analytics: <message>"}`` instead.
    """
    try:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)

            summary = await repos['forecast'].get_forecast_summary(tenant_id)
            batches = await repos['batch'].get_batch_statistics(tenant_id)
            cache = await repos['cache'].get_cache_statistics(tenant_id)
            trends = await repos['performance'].get_performance_trends(
                tenant_id, days=30
            )

            analytics: Dict[str, Any] = {"tenant_id": tenant_id}
            analytics["forecast_analytics"] = summary
            analytics["batch_analytics"] = batches
            analytics["cache_performance"] = cache
            analytics["performance_trends"] = trends
            analytics["generated_at"] = datetime.now(timezone.utc).isoformat()
            return analytics

    except Exception as e:
        logger.error("Failed to get forecast analytics",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"error": f"Failed to get analytics: {str(e)}"}
|
|
|
|
async def create_batch_prediction(
    self,
    tenant_id: str,
    batch_name: str,
    inventory_product_ids: List[str],
    forecast_days: int = 7
) -> Dict[str, Any]:
    """Create a pending batch-prediction record.

    Only the batch row is created (status ``"pending"``); this method does
    not run any predictions.

    Returns:
        dict with ``batch_id`` (str), ``status``, ``total_products`` and
        ``created_at``. ``created_at`` is the ISO string of
        ``batch.requested_at``, or None if that is unset.

    Raises:
        DatabaseError: wraps any failure. The original exception is chained
            as ``__cause__`` so its traceback is preserved.
    """
    try:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)

            batch_data = {
                "tenant_id": tenant_id,
                "batch_name": batch_name,
                "total_products": len(inventory_product_ids),
                "forecast_days": forecast_days,
                "status": "pending"
            }

            batch = await repos['batch'].create_batch(batch_data)

            logger.info("Batch prediction created",
                        batch_id=batch.id,
                        tenant_id=tenant_id,
                        total_products=len(inventory_product_ids))

            return {
                "batch_id": str(batch.id),
                "status": batch.status,
                "total_products": len(inventory_product_ids),
                "created_at": batch.requested_at.isoformat() if batch.requested_at else None
            }

    except Exception as e:
        logger.error("Failed to create batch prediction",
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to create batch: {str(e)}") from e
|
|
|
|
|
|
def _create_forecast_response_from_cache(self, cache_entry) -> ForecastResponse:
    """Build a ``ForecastResponse`` from a prediction-cache entry.

    Values taken from the cache entry: id, tenant, product, location, date,
    demand, bounds, model id and created_at. Ids are stringified. The
    weekday / weekend flags are derived from ``forecast_date``.

    Everything else is a fixed placeholder:
    - ``confidence_level`` is 0.8.
    - ``model_version`` and ``algorithm`` are ``"cached"``.
    - ``business_type`` is ``"individual"``.
    - ``is_holiday`` is False.
    - Weather and traffic fields are None.
    - ``processing_time_ms`` is 0 and ``features_used`` is an empty dict.
    """
    weekday = cache_entry.forecast_date.weekday()
    return ForecastResponse(
        id=str(cache_entry.id),
        tenant_id=str(cache_entry.tenant_id),
        inventory_product_id=str(cache_entry.inventory_product_id),
        location=cache_entry.location,
        forecast_date=cache_entry.forecast_date,
        predicted_demand=cache_entry.predicted_demand,
        confidence_lower=cache_entry.confidence_lower,
        confidence_upper=cache_entry.confidence_upper,
        confidence_level=0.8,  # placeholder: confidence level is not taken from the cache entry
        model_id=str(cache_entry.model_id),
        model_version="cached",
        algorithm="cached",
        business_type="individual",
        is_holiday=False,
        is_weekend=weekday >= 5,
        day_of_week=weekday,
        # Weather / traffic are not taken from the cache entry.
        weather_temperature=None,
        weather_precipitation=None,
        weather_description=None,
        traffic_volume=None,
        created_at=cache_entry.created_at,
        processing_time_ms=0,
        features_used={}
    )
|
|
|
|
def _create_forecast_response_from_model(self, forecast) -> ForecastResponse:
    """
    Build a ``ForecastResponse`` from a persisted forecast row.

    Identifier fields (``id``, ``tenant_id``, ``inventory_product_id``,
    ``model_id``) are converted with ``str()`` because the row may hold them
    as UUIDs. Every other response field is copied from the row unchanged.
    """
    stringified_ids = ("id", "tenant_id", "inventory_product_id", "model_id")
    copied_fields = (
        "location",
        "forecast_date",
        "predicted_demand",
        "confidence_lower",
        "confidence_upper",
        "confidence_level",
        "model_version",
        "algorithm",
        "business_type",
        "is_holiday",
        "is_weekend",
        "day_of_week",
        "weather_temperature",
        "weather_precipitation",
        "weather_description",
        "traffic_volume",
        "created_at",
        "processing_time_ms",
        "features_used",
    )

    response_kwargs = {name: str(getattr(forecast, name)) for name in stringified_ids}
    response_kwargs.update((name, getattr(forecast, name)) for name in copied_fields)
    return ForecastResponse(**response_kwargs)
|
|
|
|
def _forecast_to_dict(self, forecast) -> Dict[str, Any]:
    """
    Serialise a forecast row into a JSON-friendly summary dict.

    Identifiers are stringified and dates rendered with ``isoformat()``.
    ``created_at`` becomes ``None`` when the row has no timestamp. Only a
    summary subset of the forecast's fields is included.
    """
    created_at = forecast.created_at
    return dict(
        id=str(forecast.id),
        tenant_id=str(forecast.tenant_id),
        inventory_product_id=str(forecast.inventory_product_id),
        location=forecast.location,
        forecast_date=forecast.forecast_date.isoformat(),
        predicted_demand=forecast.predicted_demand,
        confidence_lower=forecast.confidence_lower,
        confidence_upper=forecast.confidence_upper,
        confidence_level=forecast.confidence_level,
        model_id=str(forecast.model_id),
        algorithm=forecast.algorithm,
        created_at=None if not created_at else created_at.isoformat(),
    )
|
|
|
|
# Additional helper methods from original service
|
|
async def _get_latest_model_with_fallback(
    self,
    tenant_id: str,
    inventory_product_id: str,
    *,
    timeout: float = 15.0
) -> Optional[Dict[str, Any]]:
    """
    Get the latest trained model with fallback strategies.

    First asks the training service for the best model for this product.
    If none is returned, it falls back to any model for the tenant. Each
    remote call is wrapped in ``asyncio.wait_for`` so an unresponsive
    training service cannot block forecasting indefinitely.

    A timeout or error in the product-specific lookup returns ``None``
    straight away; the tenant-wide fallback is not attempted in that case.

    Args:
        tenant_id: Tenant owning the models.
        inventory_product_id: Product to look up a specific model for.
        timeout: Per-call timeout in seconds. The default of 15s is shorter
            than the client's default 30s so we fail fast.

    Returns:
        The model metadata dict, or ``None`` when no model is found, a call
        times out, or any other error occurs. Errors are logged, not raised.
    """
    # Tracks which lookup is running so timeout/error logs say what failed
    stage = "product_model"
    try:
        model_data = await asyncio.wait_for(
            self.model_client.get_best_model_for_forecasting(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id
            ),
            timeout=timeout
        )

        if model_data:
            logger.info("Found specific model for product",
                        inventory_product_id=inventory_product_id,
                        model_id=model_data.get('model_id'))
            return model_data

        # Fallback: Try to get any model for this tenant (same timeout)
        stage = "tenant_fallback"
        fallback_model = await asyncio.wait_for(
            self.model_client.get_any_model_for_tenant(tenant_id),
            timeout=timeout
        )

        if fallback_model:
            logger.info("Using fallback model",
                        model_id=fallback_model.get('model_id'))
            return fallback_model

        logger.error("No models available for tenant", tenant_id=tenant_id)
        return None

    except asyncio.TimeoutError:
        logger.error("Timeout fetching model data from training service",
                     tenant_id=tenant_id,
                     inventory_product_id=inventory_product_id,
                     stage=stage,
                     timeout_seconds=timeout)
        return None
    except Exception as e:
        logger.error("Error getting model",
                     error=str(e),
                     tenant_id=tenant_id,
                     inventory_product_id=inventory_product_id,
                     stage=stage)
        return None
|
|
|
|
async def _prepare_forecast_features_with_fallbacks(
    self,
    tenant_id: str,
    request: ForecastRequest
) -> Dict[str, Any]:
    """
    Build the feature dict for a single forecast, with fallbacks.

    Feature sources:

    * Calendar features come from ``request.forecast_date``.
    * ``is_holiday`` comes from ``self._check_holiday``.
    * Weather comes from the external data service: a 7-day forecast at
      fixed Madrid coordinates.

    Weather falls back to fixed defaults when:

    * the forecast date is not in the returned forecasts;
    * an individual weather field is missing or ``None``;
    * the fetch or matching fails.

    Args:
        tenant_id: Tenant whose holiday calendar and weather are used.
        request: Forecast request. Its ``forecast_date`` and
            ``inventory_product_id`` are read.

    Returns:
        Feature dictionary for the predictor.
    """
    forecast_date = request.forecast_date

    # Check for school holidays using external service
    is_holiday = await self._check_holiday(tenant_id, forecast_date)

    features = {
        "date": forecast_date.isoformat(),
        "day_of_week": forecast_date.weekday(),
        "is_weekend": forecast_date.weekday() >= 5,
        "day_of_month": forecast_date.day,
        "month": forecast_date.month,
        "quarter": (forecast_date.month - 1) // 3 + 1,
        "week_of_year": forecast_date.isocalendar().week,
        "season": self._get_season(forecast_date.month),
        "is_holiday": is_holiday,
        # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
        "tenant_id": tenant_id,
        "inventory_product_id": request.inventory_product_id,
    }

    # Defaults used whenever real weather data is missing, null or unavailable
    default_weather = {
        "temperature": 20.0,
        "precipitation": 0.0,
        "humidity": 65.0,
        "wind_speed": 5.0,
        "pressure": 1013.0,
    }

    def _weather_date_key(value: Any) -> str:
        """Normalise a weather entry's ``forecast_date`` to an ISO date string."""
        # datetime is a date subclass, so test it first. A plain date has
        # no .date() method; calling it used to raise and discard all weather.
        if isinstance(value, datetime):
            return value.date().isoformat()
        if isinstance(value, date):
            return value.isoformat()
        return str(value).split('T')[0]

    # Fetch REAL weather data from external service
    try:
        # Get weather forecast for next 7 days (covers most forecast requests).
        # A None result is treated as "no forecasts" rather than an error.
        weather_forecasts = await self.data_client.fetch_weather_forecast(
            tenant_id=tenant_id,
            days=7,
            latitude=40.4168,  # Madrid coordinates (could be parameterized per tenant)
            longitude=-3.7038
        ) or []

        # Find weather for the specific forecast date
        forecast_date_str = forecast_date.isoformat().split('T')[0]
        weather_for_date = next(
            (
                weather for weather in weather_forecasts
                if _weather_date_key(weather.get('forecast_date', '')) == forecast_date_str
            ),
            None,
        )

        if weather_for_date:
            logger.info("Using REAL weather data from external service",
                        date=forecast_date_str,
                        temp=weather_for_date.get('temperature'),
                        precipitation=weather_for_date.get('precipitation'))

            # A field present but null falls back to its default instead of
            # feeding None into the model
            for key, default in default_weather.items():
                value = weather_for_date.get(key)
                features[key] = default if value is None else value
            features["weather_description"] = weather_for_date.get('description')
        else:
            logger.warning("No weather data for specific date, using defaults",
                           date=forecast_date_str,
                           forecasts_count=len(weather_forecasts))
            features.update(default_weather)
    except Exception as e:
        logger.error("Failed to fetch weather data, using defaults",
                     error=str(e),
                     date=forecast_date.isoformat())
        # Fallback to defaults on error
        features.update(default_weather)

    # NOTE: Traffic features are NOT included in predictions
    # Reason: We only have historical and real-time traffic data, not forecasts
    # The model learns traffic patterns during training (using historical data)
    # and applies those learned patterns via day_of_week, is_weekend, holidays
    # Including fake/estimated traffic values would mislead the model
    # See: TRAFFIC_DATA_ANALYSIS.md for full explanation

    return features
|
|
|
|
async def _prepare_forecast_features_with_fallbacks_and_weather_map(
    self,
    tenant_id: str,
    request: ForecastRequest,
    weather_map: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Build forecast features using a pre-fetched weather map.

    Calendar features come from ``request.forecast_date``, and ``is_holiday``
    comes from ``self._check_holiday``. Weather is looked up in
    ``weather_map`` by the forecast date's ISO date string instead of calling
    the weather service, which avoids an extra API call per request.

    Fixed default weather values are used when:

    * the date is not in the map;
    * a weather field is missing or ``None``;
    * the map itself is ``None``.

    Args:
        tenant_id: Tenant whose holiday calendar is used.
        request: Forecast request. Its ``forecast_date`` and
            ``inventory_product_id`` are read.
        weather_map: Mapping of ISO date string to a weather dict, or ``None``.

    Returns:
        Feature dictionary for the predictor.
    """
    forecast_date = request.forecast_date

    # Check for holidays using external service
    is_holiday = await self._check_holiday(tenant_id, forecast_date)

    features = {
        "date": forecast_date.isoformat(),
        "day_of_week": forecast_date.weekday(),
        "is_weekend": forecast_date.weekday() >= 5,
        "day_of_month": forecast_date.day,
        "month": forecast_date.month,
        "quarter": (forecast_date.month - 1) // 3 + 1,
        "week_of_year": forecast_date.isocalendar().week,
        "season": self._get_season(forecast_date.month),
        "is_holiday": is_holiday,
        # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
        "tenant_id": tenant_id,
        "inventory_product_id": request.inventory_product_id,
    }

    # Defaults used whenever real weather data is missing or null
    default_weather = {
        "temperature": 20.0,
        "precipitation": 0.0,
        "humidity": 65.0,
        "wind_speed": 5.0,
        "pressure": 1013.0,
    }

    # Use the pre-fetched weather data from the weather map to avoid additional
    # API calls. A None map degrades to defaults instead of raising AttributeError.
    forecast_date_str = forecast_date.isoformat().split('T')[0]
    weather_for_date = (weather_map or {}).get(forecast_date_str)

    if weather_for_date:
        logger.info("Using REAL weather data from external service via weather map",
                    date=forecast_date_str,
                    temp=weather_for_date.get('temperature'),
                    precipitation=weather_for_date.get('precipitation'))

        # A field present but null falls back to its default instead of
        # feeding None into the model
        for key, default in default_weather.items():
            value = weather_for_date.get(key)
            features[key] = default if value is None else value
        features["weather_description"] = weather_for_date.get('description')
    else:
        logger.warning("No weather data for specific date in weather map, using defaults",
                       date=forecast_date_str)
        features.update(default_weather)

    # NOTE: Traffic features are NOT included in predictions
    # Reason: We only have historical and real-time traffic data, not forecasts
    # The model learns traffic patterns during training (using historical data)
    # and applies those learned patterns via day_of_week, is_weekend, holidays
    # Including fake/estimated traffic values would mislead the model
    # See: TRAFFIC_DATA_ANALYSIS.md for full explanation

    return features
|
|
|
|
def _get_season(self, month: int) -> int:
    """
    Map a calendar month to a season code.

    Returns 1 for Dec-Feb (winter), 2 for Mar-May (spring) and 3 for
    Jun-Aug (summer). Every other value maps to 4 (autumn), including
    months outside 1-12.
    """
    season_table = (
        ((12, 1, 2), 1),  # Winter
        ((3, 4, 5), 2),   # Spring
        ((6, 7, 8), 3),   # Summer
    )
    for months, season_code in season_table:
        if month in months:
            return season_code
    return 4  # Autumn
|
|
|
|
async def _check_holiday(self, tenant_id: str, date_obj: date) -> bool:
    """
    Decide whether ``date_obj`` is a holiday for the tenant.

    If the data service returns a calendar for the tenant, the result of
    its school-holiday check for that calendar is returned. Without a
    calendar, or if any step raises, it falls back to the fixed Spanish
    national holiday list. A failure is logged as a warning.
    """
    try:
        calendar = await self.data_client.fetch_tenant_calendar(tenant_id)
        if not calendar:
            # No calendar configured for this tenant: use the national list
            return self._is_spanish_national_holiday(date_obj)

        return await self.data_client.check_school_holiday(
            calendar_id=calendar["calendar_id"],
            check_date=date_obj.isoformat(),
            tenant_id=tenant_id
        )
    except Exception as e:
        logger.warning(
            "Holiday check failed, falling back to national holidays",
            error=str(e),
            tenant_id=tenant_id,
            date=date_obj.isoformat()
        )
        return self._is_spanish_national_holiday(date_obj)
|
|
|
|
def _is_spanish_national_holiday(self, date_obj: date) -> bool:
    """
    Fallback holiday check against Spain's fixed-date national holidays.

    Only fixed (month, day) pairs are listed, so movable holidays such as
    Good Friday and regional holidays are not recognised.
    """
    fixed_holidays = frozenset({
        (1, 1),    # New Year's Day
        (1, 6),    # Epiphany
        (5, 1),    # Labour Day
        (8, 15),   # Assumption
        (10, 12),  # National Day
        (11, 1),   # All Saints' Day
        (12, 6),   # Constitution Day
        (12, 8),   # Immaculate Conception
        (12, 25),  # Christmas Day
    })
    return (date_obj.month, date_obj.day) in fixed_holidays
|
|
|
|
def _apply_business_rules(
    self,
    prediction: Dict[str, float],
    request: "ForecastRequest",
    features: Dict[str, Any]
) -> Dict[str, float]:
    """
    Apply Spanish bakery business rules to a raw model prediction.

    The point prediction is multiplied by an adjustment factor:

    * x0.8 on weekends;
    * x0.5 on holidays;
    * x0.7 when precipitation exceeds 2.0.

    The result is clamped at 0. The confidence interval width is scaled by
    the same factor and re-centred on the adjusted prediction, subject to
    two floors:

    * The lower bound is at least ``max(20% of prediction, 5.0)``, but
      never above the prediction itself. Without that cap, predictions
      below 5 ended up outside their own interval.
    * The upper bound is at least ``lower_bound + 10``.

    Args:
        prediction: Model output. ``"prediction"`` is required.
            ``"lower_bound"`` and ``"upper_bound"`` default to -20%/+20% of
            it when missing or ``None``. ``"confidence_level"`` defaults
            to 0.8.
        request: The originating forecast request (not read here).
        features: Feature dict. ``is_weekend``, ``is_holiday`` and
            ``precipitation`` are read.

    Returns:
        Dict with ``prediction``, ``lower_bound``, ``upper_bound``,
        ``confidence_interval`` (upper - lower), ``confidence_level`` and
        ``adjustment_factor``.
    """
    base_prediction = prediction["prediction"]

    # Ensure confidence bounds exist with fallbacks (explicit None counts as missing)
    lower_bound = prediction.get("lower_bound")
    if lower_bound is None:
        lower_bound = base_prediction * 0.8
    upper_bound = prediction.get("upper_bound")
    if upper_bound is None:
        upper_bound = base_prediction * 1.2

    # Apply adjustment factors
    adjustment_factor = 1.0

    if features.get("is_weekend", False):
        adjustment_factor *= 0.8

    if features.get("is_holiday", False):
        adjustment_factor *= 0.5

    # Weather adjustment: heavy precipitation lowers expected demand
    precipitation = features.get("precipitation") or 0.0
    if precipitation > 2.0:
        adjustment_factor *= 0.7

    # Apply adjustments to prediction
    adjusted_prediction = max(0, base_prediction * adjustment_factor)

    # For confidence bounds, preserve relative interval width while respecting minimum bounds
    original_interval = upper_bound - lower_bound
    adjusted_interval = original_interval * adjustment_factor

    # Minimum lower bound: 20% of prediction or 5, whichever is larger --
    # capped at the prediction so the point estimate stays inside its interval
    min_lower_bound = min(max(adjusted_prediction * 0.2, 5.0), adjusted_prediction)
    adjusted_lower = max(min_lower_bound, adjusted_prediction - (adjusted_interval / 2))
    adjusted_upper = max(adjusted_lower + 10, adjusted_prediction + (adjusted_interval / 2))

    return {
        "prediction": adjusted_prediction,
        "lower_bound": adjusted_lower,
        "upper_bound": adjusted_upper,
        "confidence_interval": adjusted_upper - adjusted_lower,
        "confidence_level": prediction.get("confidence_level", 0.8),
        "adjustment_factor": adjustment_factor
    }
|
|
|
|
|
|
# Legacy compatibility alias: the old name ``ForecastingService`` refers to
# EnhancedForecastingService, so existing imports of the old name keep working.
ForecastingService = EnhancedForecastingService
|