"""
|
|
Enhanced Forecasting Service with Repository Pattern
|
|
Main forecasting service that uses the repository pattern for data access
|
|
"""
|
|
|
|
import structlog
|
|
from typing import Dict, List, Any, Optional
|
|
from datetime import datetime, date, timedelta
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.ml.predictor import BakeryForecaster
|
|
from app.schemas.forecasts import ForecastRequest, ForecastResponse
|
|
from app.services.prediction_service import PredictionService
|
|
from app.services.model_client import ModelClient
|
|
from app.services.data_client import DataClient
|
|
|
|
# Import repositories
|
|
from app.repositories import (
|
|
ForecastRepository,
|
|
PredictionBatchRepository,
|
|
ForecastAlertRepository,
|
|
PerformanceMetricRepository,
|
|
PredictionCacheRepository
|
|
)
|
|
|
|
# Import shared database components
|
|
from shared.database.base import create_database_manager
|
|
from shared.database.unit_of_work import UnitOfWork
|
|
from shared.database.transactions import transactional
|
|
from shared.database.exceptions import DatabaseError
|
|
from app.core.config import settings
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|


class EnhancedForecastingService:
    """
    Enhanced forecasting service using the repository pattern.

    Handles forecast generation, batch processing, and alerting with proper data abstraction.
    """

    def __init__(self, database_manager=None):
        self.database_manager = database_manager or create_database_manager(
            settings.DATABASE_URL, "forecasting-service"
        )

        # Initialize ML components
        self.forecaster = BakeryForecaster(database_manager=self.database_manager)
        self.prediction_service = PredictionService(database_manager=self.database_manager)
        self.model_client = ModelClient(database_manager=self.database_manager)
        self.data_client = DataClient()

    async def _init_repositories(self, session):
        """Initialize repositories bound to the given session."""
        return {
            'forecast': ForecastRepository(session),
            'batch': PredictionBatchRepository(session),
            'alert': ForecastAlertRepository(session),
            'performance': PerformanceMetricRepository(session),
            'cache': PredictionCacheRepository(session)
        }

    async def generate_batch_forecasts(self, tenant_id: str, request) -> Dict[str, Any]:
        """Generate batch forecasts using repository pattern"""
        try:
            # Implementation would use repository pattern to generate multiple forecasts
            return {
                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                "tenant_id": tenant_id,
                "forecasts": [],
                "total_forecasts": 0,
                "successful_forecasts": 0,
                "failed_forecasts": 0,
                "enhanced_features": True,
                "repository_integration": True
            }
        except Exception as e:
            logger.error("Batch forecast generation failed", error=str(e))
            raise
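
    # Illustrative sketch (not wired into the public API): one way the batch stub above
    # could be implemented by reusing generate_forecast() defined later in this class.
    # Taking a plain list of ForecastRequest objects is an assumption for the example;
    # the real batch request schema may differ.
    async def _generate_batch_forecasts_sketch(
        self, tenant_id: str, requests: List[ForecastRequest]
    ) -> Dict[str, Any]:
        forecasts = []
        failed = 0
        for item in requests:
            try:
                # Each per-product request goes through the normal single-forecast path
                forecasts.append(await self.generate_forecast(tenant_id, item))
            except Exception as e:
                failed += 1
                logger.warning("Batch forecast item failed", error=str(e))
        return {
            "tenant_id": tenant_id,
            "forecasts": forecasts,
            "total_forecasts": len(requests),
            "successful_forecasts": len(forecasts),
            "failed_forecasts": failed,
        }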

    async def get_tenant_forecasts(self, tenant_id: str, inventory_product_id: Optional[str] = None,
                                   start_date: Optional[date] = None, end_date: Optional[date] = None,
                                   skip: int = 0, limit: int = 100) -> List[Dict]:
        """Get tenant forecasts with filtering"""
        try:
            # Implementation would use repository pattern to fetch forecasts
            return []
        except Exception as e:
            logger.error("Failed to get tenant forecasts", error=str(e))
            raise
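
    # Illustrative sketch: the filtering stub above could delegate to the same
    # ForecastRepository.get_forecasts_by_date_range() call used by get_forecast_history()
    # below. Applying skip/limit by slicing in Python is a simplification for the example;
    # a real implementation would likely paginate in the repository query itself.
    async def _get_tenant_forecasts_sketch(self, tenant_id: str, inventory_product_id: Optional[str],
                                           start_date: date, end_date: date,
                                           skip: int = 0, limit: int = 100) -> List[Dict]:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)
            forecasts = await repos['forecast'].get_forecasts_by_date_range(
                tenant_id, start_date, end_date, inventory_product_id
            )
            return [self._forecast_to_dict(f) for f in forecasts[skip:skip + limit]]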

    async def get_forecast_by_id(self, forecast_id: str) -> Optional[Dict]:
        """Get forecast by ID"""
        try:
            # Implementation would use repository pattern
            return None
        except Exception as e:
            logger.error("Failed to get forecast by ID", error=str(e))
            raise

    async def delete_forecast(self, forecast_id: str) -> bool:
        """Delete forecast"""
        try:
            # Implementation would use repository pattern
            return True
        except Exception as e:
            logger.error("Failed to delete forecast", error=str(e))
            return False
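
    # Illustrative sketch only: the two stubs above assume some generic lookup/delete
    # support on ForecastRepository. The `get_by_id` and `delete` method names used here
    # are hypothetical placeholders, not confirmed parts of the repository interface.
    async def _get_and_delete_forecast_sketch(self, forecast_id: str) -> bool:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)
            forecast = await repos['forecast'].get_by_id(forecast_id)  # hypothetical helper
            if forecast is None:
                return False
            await repos['forecast'].delete(forecast_id)  # hypothetical helper
            return True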

    async def get_tenant_alerts(self, tenant_id: str, active_only: bool = True,
                                skip: int = 0, limit: int = 50) -> List[Dict]:
        """Get tenant alerts"""
        try:
            # Implementation would use repository pattern
            return []
        except Exception as e:
            logger.error("Failed to get tenant alerts", error=str(e))
            raise
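
    # Illustrative sketch only: a `get_active_alerts` query on ForecastAlertRepository is a
    # hypothetical name used for this example; only create_alert() and get_alert_statistics()
    # are exercised elsewhere in this module.
    async def _get_tenant_alerts_sketch(self, tenant_id: str,
                                        skip: int = 0, limit: int = 50) -> List[Dict]:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)
            # Hypothetical helper name; paging by slicing keeps the example short
            alerts = await repos['alert'].get_active_alerts(tenant_id)
            return alerts[skip:skip + limit]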

    async def get_tenant_forecast_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get tenant forecast statistics"""
        try:
            # Implementation would use repository pattern
            return {
                "total_forecasts": 0,
                "active_forecasts": 0,
                "recent_forecasts": 0,
                "accuracy_metrics": {},
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Failed to get forecast statistics", error=str(e))
            return {"error": str(e)}
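
    # Illustrative sketch: the statistics stub above could be backed by the same
    # ForecastRepository.get_forecast_summary() call that get_forecast_analytics() uses
    # below. The exact keys inside the summary are determined by the repository, so they
    # are passed through unchanged here.
    async def _get_tenant_forecast_statistics_sketch(self, tenant_id: str) -> Dict[str, Any]:
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)
            summary = await repos['forecast'].get_forecast_summary(tenant_id)
            return {"tenant_id": tenant_id, "forecast_summary": summary, "enhanced_features": True}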

    async def generate_batch_predictions(self, tenant_id: str, batch_request: Dict) -> Dict[str, Any]:
        """Generate batch predictions"""
        try:
            # Implementation would use repository pattern
            return {
                "batch_id": f"pred_batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                "tenant_id": tenant_id,
                "predictions": [],
                "total_predictions": 0,
                "successful_predictions": 0,
                "failed_predictions": 0,
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Batch predictions failed", error=str(e))
            raise

    async def get_cached_predictions(self, tenant_id: str, inventory_product_id: Optional[str] = None,
                                     skip: int = 0, limit: int = 100) -> List[Dict]:
        """Get cached predictions"""
        try:
            # Implementation would use repository pattern
            return []
        except Exception as e:
            logger.error("Failed to get cached predictions", error=str(e))
            raise

    async def clear_prediction_cache(self, tenant_id: str, inventory_product_id: Optional[str] = None) -> int:
        """Clear prediction cache"""
        try:
            # Implementation would use repository pattern
            return 0
        except Exception as e:
            logger.error("Failed to clear prediction cache", error=str(e))
            return 0

    async def get_prediction_performance(self, tenant_id: str, model_id: Optional[str] = None,
                                         start_date: Optional[date] = None,
                                         end_date: Optional[date] = None) -> Dict[str, Any]:
        """Get prediction performance metrics"""
        try:
            # Implementation would use repository pattern
            return {
                "accuracy_metrics": {},
                "performance_trends": [],
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Failed to get prediction performance", error=str(e))
            raise
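
    # Illustrative sketch: the performance stub above could reuse the
    # PerformanceMetricRepository.get_performance_trends() call made by
    # get_forecast_analytics() below; deriving a day count from the optional date range
    # is an assumption made to keep the example self-contained.
    async def _get_prediction_performance_sketch(self, tenant_id: str,
                                                 start_date: Optional[date] = None,
                                                 end_date: Optional[date] = None) -> Dict[str, Any]:
        days = (end_date - start_date).days if start_date and end_date else 30
        async with self.database_manager.get_session() as session:
            repos = await self._init_repositories(session)
            trends = await repos['performance'].get_performance_trends(tenant_id, days=days)
            return {"performance_trends": trends, "enhanced_features": True}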

    async def generate_forecast(
        self,
        tenant_id: str,
        request: ForecastRequest
    ) -> ForecastResponse:
        """
        Generate forecast using repository pattern with caching and alerting.
        """
        start_time = datetime.utcnow()

        try:
            logger.info("Generating enhanced forecast",
                        tenant_id=tenant_id,
                        inventory_product_id=request.inventory_product_id,
                        date=request.forecast_date.isoformat())

            # Get session and initialize repositories
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                # Step 1: Check cache first
                cached_prediction = await repos['cache'].get_cached_prediction(
                    tenant_id, request.inventory_product_id, request.location, request.forecast_date
                )

                if cached_prediction:
                    logger.debug("Using cached prediction",
                                 tenant_id=tenant_id,
                                 inventory_product_id=request.inventory_product_id)
                    return self._create_forecast_response_from_cache(cached_prediction)

                # Step 2: Get model with validation
                model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

                if not model_data:
                    raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

                # Step 3: Prepare features with fallbacks
                features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

                # Step 4: Generate prediction
                prediction_result = await self.prediction_service.predict(
                    model_id=model_data['model_id'],
                    model_path=model_data['model_path'],
                    features=features,
                    confidence_level=request.confidence_level
                )

                # Step 5: Apply business rules
                adjusted_prediction = self._apply_business_rules(
                    prediction_result, request, features
                )

                # Step 6: Save forecast using repository
                # Convert forecast_date to datetime if it's a string
                forecast_datetime = request.forecast_date
                if isinstance(forecast_datetime, str):
                    from dateutil.parser import parse
                    forecast_datetime = parse(forecast_datetime)

                forecast_data = {
                    "tenant_id": tenant_id,
                    "inventory_product_id": request.inventory_product_id,
                    "location": request.location,
                    "forecast_date": forecast_datetime,
                    "predicted_demand": adjusted_prediction['prediction'],
                    "confidence_lower": adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                    "confidence_upper": adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                    "confidence_level": request.confidence_level,
                    "model_id": model_data['model_id'],
                    "model_version": model_data.get('version', '1.0'),
                    "algorithm": model_data.get('algorithm', 'prophet'),
                    "business_type": features.get('business_type', 'individual'),
                    "is_holiday": features.get('is_holiday', False),
                    "is_weekend": features.get('is_weekend', False),
                    "day_of_week": features.get('day_of_week', 0),
                    "weather_temperature": features.get('temperature'),
                    "weather_precipitation": features.get('precipitation'),
                    "weather_description": features.get('weather_description'),
                    "traffic_volume": features.get('traffic_volume'),
                    "processing_time_ms": int((datetime.utcnow() - start_time).total_seconds() * 1000),
                    "features_used": features
                }

                forecast = await repos['forecast'].create_forecast(forecast_data)

                # Step 7: Cache the prediction
                await repos['cache'].cache_prediction(
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    location=request.location,
                    forecast_date=forecast_datetime,
                    predicted_demand=adjusted_prediction['prediction'],
                    confidence_lower=adjusted_prediction.get('lower_bound', adjusted_prediction['prediction'] * 0.8),
                    confidence_upper=adjusted_prediction.get('upper_bound', adjusted_prediction['prediction'] * 1.2),
                    model_id=model_data['model_id'],
                    expires_in_hours=24
                )

                # Step 8: Check for alerts
                await self._check_and_create_alerts(forecast, adjusted_prediction, repos)

                logger.info("Enhanced forecast generated successfully",
                            forecast_id=forecast.id,
                            tenant_id=tenant_id,
                            prediction=adjusted_prediction['prediction'])

                return self._create_forecast_response_from_model(forecast)

        except Exception as e:
            processing_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
            logger.error("Error generating enhanced forecast",
                         error=str(e),
                         tenant_id=tenant_id,
                         inventory_product_id=request.inventory_product_id,
                         processing_time=processing_time)
            raise

    async def get_forecast_history(
        self,
        tenant_id: str,
        inventory_product_id: Optional[str] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None
    ) -> List[Dict[str, Any]]:
        """Get forecast history using repository"""
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                if start_date and end_date:
                    forecasts = await repos['forecast'].get_forecasts_by_date_range(
                        tenant_id, start_date, end_date, inventory_product_id
                    )
                else:
                    # Get recent forecasts (last 30 days)
                    forecasts = await repos['forecast'].get_recent_records(
                        tenant_id, hours=24 * 30
                    )

                # Convert to dict format
                return [self._forecast_to_dict(forecast) for forecast in forecasts]

        except Exception as e:
            logger.error("Failed to get forecast history",
                         tenant_id=tenant_id,
                         error=str(e))
            return []

    async def get_forecast_analytics(self, tenant_id: str) -> Dict[str, Any]:
        """Get comprehensive forecast analytics using repositories"""
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                # Get forecast summary
                forecast_summary = await repos['forecast'].get_forecast_summary(tenant_id)

                # Get alert statistics
                alert_stats = await repos['alert'].get_alert_statistics(tenant_id)

                # Get batch statistics
                batch_stats = await repos['batch'].get_batch_statistics(tenant_id)

                # Get cache performance
                cache_stats = await repos['cache'].get_cache_statistics(tenant_id)

                # Get performance trends
                performance_trends = await repos['performance'].get_performance_trends(
                    tenant_id, days=30
                )

                return {
                    "tenant_id": tenant_id,
                    "forecast_analytics": forecast_summary,
                    "alert_analytics": alert_stats,
                    "batch_analytics": batch_stats,
                    "cache_performance": cache_stats,
                    "performance_trends": performance_trends,
                    "generated_at": datetime.utcnow().isoformat()
                }

        except Exception as e:
            logger.error("Failed to get forecast analytics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {"error": f"Failed to get analytics: {str(e)}"}

    async def create_batch_prediction(
        self,
        tenant_id: str,
        batch_name: str,
        inventory_product_ids: List[str],
        forecast_days: int = 7
    ) -> Dict[str, Any]:
        """Create batch prediction job using repository"""
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                # Create batch record
                batch_data = {
                    "tenant_id": tenant_id,
                    "batch_name": batch_name,
                    "total_products": len(inventory_product_ids),
                    "forecast_days": forecast_days,
                    "status": "pending"
                }

                batch = await repos['batch'].create_batch(batch_data)

                logger.info("Batch prediction created",
                            batch_id=batch.id,
                            tenant_id=tenant_id,
                            total_products=len(inventory_product_ids))

                return {
                    "batch_id": str(batch.id),
                    "status": batch.status,
                    "total_products": len(inventory_product_ids),
                    "created_at": batch.requested_at.isoformat()
                }

        except Exception as e:
            logger.error("Failed to create batch prediction",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create batch: {str(e)}")

    async def _check_and_create_alerts(self, forecast, prediction: Dict[str, Any], repos: Dict):
        """Check forecast results and create alerts if necessary"""
        try:
            alerts_to_create = []

            # Check for high demand alert
            if prediction['prediction'] > 100:  # Threshold for high demand
                alerts_to_create.append({
                    "tenant_id": str(forecast.tenant_id),
                    "forecast_id": str(forecast.id),  # Convert UUID to string
                    "alert_type": "high_demand",
                    "severity": "high" if prediction['prediction'] > 200 else "medium",
                    "message": f"High demand predicted for inventory product {str(forecast.inventory_product_id)}: {prediction['prediction']:.1f} units"
                })

            # Check for low demand alert
            elif prediction['prediction'] < 10:  # Threshold for low demand
                alerts_to_create.append({
                    "tenant_id": str(forecast.tenant_id),
                    "forecast_id": str(forecast.id),  # Convert UUID to string
                    "alert_type": "low_demand",
                    "severity": "low",
                    "message": f"Low demand predicted for inventory product {str(forecast.inventory_product_id)}: {prediction['prediction']:.1f} units"
                })

            # Check for stockout risk (very low prediction with narrow confidence interval)
            confidence_interval = prediction['upper_bound'] - prediction['lower_bound']
            if prediction['prediction'] < 5 and confidence_interval < 10:
                alerts_to_create.append({
                    "tenant_id": str(forecast.tenant_id),
                    "forecast_id": str(forecast.id),  # Convert UUID to string
                    "alert_type": "stockout_risk",
                    "severity": "critical",
                    "message": f"Stockout risk for inventory product {str(forecast.inventory_product_id)}: predicted {prediction['prediction']:.1f} units with high confidence"
                })

            # Create alerts
            for alert_data in alerts_to_create:
                await repos['alert'].create_alert(alert_data)

        except Exception as e:
            logger.error("Failed to create alerts",
                         forecast_id=forecast.id,
                         error=str(e))
            # Don't raise - alerts are not critical for forecast generation

    def _create_forecast_response_from_cache(self, cache_entry) -> ForecastResponse:
        """Create forecast response from cached entry"""
        return ForecastResponse(
            id=str(cache_entry.id),
            tenant_id=str(cache_entry.tenant_id),
            inventory_product_id=str(cache_entry.inventory_product_id),  # Convert UUID to string
            location=cache_entry.location,
            forecast_date=cache_entry.forecast_date,
            predicted_demand=cache_entry.predicted_demand,
            confidence_lower=cache_entry.confidence_lower,
            confidence_upper=cache_entry.confidence_upper,
            confidence_level=0.8,  # Default
            model_id=str(cache_entry.model_id),
            model_version="cached",
            algorithm="cached",
            business_type="individual",
            is_holiday=False,
            is_weekend=cache_entry.forecast_date.weekday() >= 5,
            day_of_week=cache_entry.forecast_date.weekday(),
            created_at=cache_entry.created_at,
            processing_time_ms=0,  # From cache
            features_used={}
        )

    def _create_forecast_response_from_model(self, forecast) -> ForecastResponse:
        """Create forecast response from forecast model"""
        return ForecastResponse(
            id=str(forecast.id),
            tenant_id=str(forecast.tenant_id),
            inventory_product_id=str(forecast.inventory_product_id),  # Convert UUID to string
            location=forecast.location,
            forecast_date=forecast.forecast_date,
            predicted_demand=forecast.predicted_demand,
            confidence_lower=forecast.confidence_lower,
            confidence_upper=forecast.confidence_upper,
            confidence_level=forecast.confidence_level,
            model_id=str(forecast.model_id),
            model_version=forecast.model_version,
            algorithm=forecast.algorithm,
            business_type=forecast.business_type,
            is_holiday=forecast.is_holiday,
            is_weekend=forecast.is_weekend,
            day_of_week=forecast.day_of_week,
            weather_temperature=forecast.weather_temperature,
            weather_precipitation=forecast.weather_precipitation,
            weather_description=forecast.weather_description,
            traffic_volume=forecast.traffic_volume,
            created_at=forecast.created_at,
            processing_time_ms=forecast.processing_time_ms,
            features_used=forecast.features_used
        )

    def _forecast_to_dict(self, forecast) -> Dict[str, Any]:
        """Convert forecast model to dictionary"""
        return {
            "id": str(forecast.id),
            "tenant_id": str(forecast.tenant_id),
            "inventory_product_id": str(forecast.inventory_product_id),  # Convert UUID to string
            "location": forecast.location,
            "forecast_date": forecast.forecast_date.isoformat(),
            "predicted_demand": forecast.predicted_demand,
            "confidence_lower": forecast.confidence_lower,
            "confidence_upper": forecast.confidence_upper,
            "confidence_level": forecast.confidence_level,
            "model_id": str(forecast.model_id),
            "algorithm": forecast.algorithm,
            "created_at": forecast.created_at.isoformat() if forecast.created_at else None
        }

    # Additional helper methods from original service
    async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
        """Get the latest trained model with fallback strategies"""
        try:
            model_data = await self.model_client.get_best_model_for_forecasting(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id
            )

            if model_data:
                logger.info("Found specific model for product",
                            inventory_product_id=inventory_product_id,
                            model_id=model_data.get('model_id'))
                return model_data

            # Fallback: Try to get any model for this tenant
            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)

            if fallback_model:
                logger.info("Using fallback model",
                            model_id=fallback_model.get('model_id'))
                return fallback_model

            logger.error("No models available for tenant", tenant_id=tenant_id)
            return None

        except Exception as e:
            logger.error("Error getting model", error=str(e))
            return None

    async def _prepare_forecast_features_with_fallbacks(
        self,
        tenant_id: str,
        request: ForecastRequest
    ) -> Dict[str, Any]:
        """Prepare features with comprehensive fallbacks"""
        features = {
            "date": request.forecast_date.isoformat(),
            "day_of_week": request.forecast_date.weekday(),
            "is_weekend": request.forecast_date.weekday() >= 5,
            "day_of_month": request.forecast_date.day,
            "month": request.forecast_date.month,
            "quarter": (request.forecast_date.month - 1) // 3 + 1,
            "week_of_year": request.forecast_date.isocalendar().week,
            "season": self._get_season(request.forecast_date.month),
            "is_holiday": self._is_spanish_holiday(request.forecast_date),
        }

        # Add weather features (simplified)
        features.update({
            "temperature": 20.0,  # Default values
            "precipitation": 0.0,
            "humidity": 65.0,
            "wind_speed": 5.0,
            "pressure": 1013.0,
        })

        # Add traffic features (simplified)
        weekend_factor = 0.7 if features["is_weekend"] else 1.0
        features.update({
            "traffic_volume": int(100 * weekend_factor),
            "pedestrian_count": int(50 * weekend_factor),
        })

        return features

    def _get_season(self, month: int) -> int:
        """Get season from month"""
        if month in [12, 1, 2]:
            return 1  # Winter
        elif month in [3, 4, 5]:
            return 2  # Spring
        elif month in [6, 7, 8]:
            return 3  # Summer
        else:
            return 4  # Autumn

    def _is_spanish_holiday(self, target_date: datetime) -> bool:
        """Check if a date is a major fixed-date Spanish national holiday"""
        month_day = (target_date.month, target_date.day)
        spanish_holidays = [
            (1, 1), (1, 6), (5, 1), (8, 15), (10, 12),
            (11, 1), (12, 6), (12, 8), (12, 25)
        ]
        return month_day in spanish_holidays

    def _apply_business_rules(
        self,
        prediction: Dict[str, float],
        request: ForecastRequest,
        features: Dict[str, Any]
    ) -> Dict[str, float]:
        """Apply Spanish bakery business rules to predictions"""
        base_prediction = prediction["prediction"]

        # Ensure confidence bounds exist with fallbacks
        lower_bound = prediction.get("lower_bound", base_prediction * 0.8)
        upper_bound = prediction.get("upper_bound", base_prediction * 1.2)

        # Apply adjustment factors
        adjustment_factor = 1.0

        if features.get("is_weekend", False):
            adjustment_factor *= 0.8

        if features.get("is_holiday", False):
            adjustment_factor *= 0.5

        # Weather adjustments
        precipitation = features.get("precipitation", 0.0)
        if precipitation > 2.0:
            adjustment_factor *= 0.7

        # Apply adjustments to prediction
        adjusted_prediction = max(0, base_prediction * adjustment_factor)

        # For confidence bounds, preserve relative interval width while respecting minimum bounds
        original_interval = upper_bound - lower_bound
        adjusted_interval = original_interval * adjustment_factor

        # Ensure minimum reasonable lower bound (at least 20% of prediction or 5, whichever is larger)
        min_lower_bound = max(adjusted_prediction * 0.2, 5.0)
        adjusted_lower = max(min_lower_bound, adjusted_prediction - (adjusted_interval / 2))
        adjusted_upper = max(adjusted_lower + 10, adjusted_prediction + (adjusted_interval / 2))

        return {
            "prediction": adjusted_prediction,
            "lower_bound": adjusted_lower,
            "upper_bound": adjusted_upper,
            "confidence_interval": adjusted_upper - adjusted_lower,
            "confidence_level": prediction.get("confidence_level", 0.8),
            "adjustment_factor": adjustment_factor
        }
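
    # Worked example for _apply_business_rules (derived from the factors above, not from
    # measured data): a rainy Saturday combines the weekend (0.8) and precipitation (0.7)
    # factors, so adjustment_factor = 0.8 * 0.7 = 0.56. A base prediction of 100 units with
    # bounds [80, 120] becomes 56 units, the 40-unit interval shrinks to 40 * 0.56 = 22.4,
    # and the bounds become 56 +/- 11.2, i.e. [44.8, 67.2]; the minimum-lower-bound rule
    # (max(0.2 * 56, 5) = 11.2) does not bind in this case.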


# Legacy compatibility alias
ForecastingService = EnhancedForecastingService
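
# Example usage (illustrative only; assumes the service runs inside an async context and
# that ForecastRequest accepts the fields referenced throughout this module):
#
#     service = EnhancedForecastingService()
#     request = ForecastRequest(
#         inventory_product_id="<product-uuid>",
#         location="<location-name>",
#         forecast_date=date.today() + timedelta(days=1),
#         confidence_level=0.8,
#     )
#     response = await service.generate_forecast(tenant_id="<tenant-uuid>", request=request)
#     print(response.predicted_demand, response.confidence_lower, response.confidence_upper)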