243 lines
7.8 KiB
Python
243 lines
7.8 KiB
Python
|
|
"""
|
||
|
|
Model Selection System
|
||
|
|
Determines whether to use Prophet-only or Hybrid Prophet+XGBoost models
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
from typing import Dict, Any, Optional
|
||
|
|
import structlog
|
||
|
|
|
||
|
|
logger = structlog.get_logger()
|
||
|
|
|
||
|
|
|
||
|
|
class ModelSelector:
    """
    Intelligent model selection based on data characteristics.

    Decision Criteria:
    - Data size: Hybrid needs more data (min 90 days)
    - Complexity: High variance benefits from XGBoost
    - Seasonality strength: Weak seasonality benefits from XGBoost
    - Historical performance: Compare models on validation set
    """

    # Thresholds for model selection
    MIN_DATA_POINTS_HYBRID = 90  # Minimum data points for hybrid
    HIGH_VARIANCE_THRESHOLD = 0.5  # CV > 0.5 suggests complex patterns
    LOW_SEASONALITY_THRESHOLD = 0.3  # Weak seasonal patterns
    HYBRID_IMPROVEMENT_THRESHOLD = 0.05  # Relative MAPE improvement (5%) to justify hybrid

    # Data-size scoring: above LARGE favors hybrid, below SMALL favors Prophet
    LARGE_DATASET_THRESHOLD = 180
    SMALL_DATASET_THRESHOLD = 120

    # Share of zero observations above which the hybrid gets an extra point
    HIGH_ZERO_RATIO_THRESHOLD = 0.3

    # Product categories that add a point to one side of the score
    HYBRID_FRIENDLY_CATEGORIES = ('seasonal', 'cakes')
    PROPHET_FRIENDLY_CATEGORIES = ('bread', 'savory')

    # Parameters of the series analysis
    MIN_POINTS_SEASONALITY = 14  # Below this the seasonality proxy is not computed
    MIN_POINTS_TREND = 30  # Below this the trend regression is not computed
    SEASONALITY_WINDOW = 7  # Rolling-mean window used by the seasonality proxy
    DEFAULT_SEASONALITY_STRENGTH = 0.5  # Used when the proxy cannot be computed

    def select_model_type(
        self,
        df: pd.DataFrame,
        product_category: str = "unknown",
        force_prophet: bool = False,
        force_hybrid: bool = False
    ) -> str:
        """
        Select best model type based on data characteristics.

        Forced selections are honored first (Prophet wins if both flags are
        set). Series shorter than MIN_DATA_POINTS_HYBRID always use Prophet.
        Otherwise five factors (variance, seasonality, data size, product
        category, zero ratio) add points to a hybrid score and a Prophet
        score; hybrid is chosen only when its score is strictly higher.

        Args:
            df: Training data with 'y' column
            product_category: Product category (bread, pastries, etc.)
            force_prophet: Force Prophet-only model
            force_hybrid: Force hybrid model

        Returns:
            "prophet" or "hybrid"

        Raises:
            KeyError: If df has enough rows to be analyzed but no 'y' column.
        """
        # Honor forced selections
        if force_prophet:
            if force_hybrid:
                # Contradictory configuration: keep the historical precedence
                # (Prophet) but make the conflict visible in the logs.
                logger.warning(
                    "Both force_prophet and force_hybrid are set, Prophet takes precedence"
                )
            logger.info("Prophet-only model forced by configuration")
            return "prophet"

        if force_hybrid:
            logger.info("Hybrid model forced by configuration")
            return "hybrid"

        n_points = len(df)

        # Check minimum data requirements
        if n_points < self.MIN_DATA_POINTS_HYBRID:
            logger.info(
                "Insufficient data for hybrid model, using Prophet",
                data_points=n_points,
                min_required=self.MIN_DATA_POINTS_HYBRID
            )
            return "prophet"

        # Calculate data characteristics
        characteristics = self._analyze_data_characteristics(df)
        cv = characteristics['coefficient_of_variation']
        seasonality = characteristics['seasonality_strength']

        # Decision logic
        score_hybrid = 0
        score_prophet = 0

        # Factor 1: Data complexity (variance)
        if cv > self.HIGH_VARIANCE_THRESHOLD:
            score_hybrid += 2
            logger.debug("High variance detected, favoring hybrid", cv=cv)
        else:
            score_prophet += 1

        # Factor 2: Seasonality strength
        if seasonality < self.LOW_SEASONALITY_THRESHOLD:
            score_hybrid += 2
            logger.debug("Weak seasonality detected, favoring hybrid", strength=seasonality)
        else:
            score_prophet += 1

        # Factor 3: Data size (more data = better for hybrid)
        if n_points > self.LARGE_DATASET_THRESHOLD:
            score_hybrid += 1
        elif n_points < self.SMALL_DATASET_THRESHOLD:
            score_prophet += 1

        # Factor 4: Product category considerations
        if product_category in self.HYBRID_FRIENDLY_CATEGORIES:
            # Event-driven products benefit from XGBoost pattern learning
            score_hybrid += 1
        elif product_category in self.PROPHET_FRIENDLY_CATEGORIES:
            # Stable products work well with Prophet
            score_prophet += 1

        # Factor 5: Zero ratio (sparse data)
        if characteristics['zero_ratio'] > self.HIGH_ZERO_RATIO_THRESHOLD:
            # High zero ratio suggests difficult forecasting, hybrid might help
            score_hybrid += 1

        # Make decision (ties go to the simpler Prophet model)
        selected_model = "hybrid" if score_hybrid > score_prophet else "prophet"

        logger.info(
            "Model selection complete",
            selected_model=selected_model,
            score_hybrid=score_hybrid,
            score_prophet=score_prophet,
            data_points=n_points,
            cv=cv,
            seasonality=seasonality,
            category=product_category
        )

        return selected_model

    def _analyze_data_characteristics(self, df: pd.DataFrame) -> Dict[str, float]:
        """
        Analyze time series characteristics.

        Missing or non-finite values in 'y' are excluded from every statistic
        so that gaps in the series do not distort the result. An empty (or
        entirely missing) series yields zeros plus the default seasonality
        strength instead of NaN values.

        Args:
            df: DataFrame with 'y' column (sales data)

        Returns:
            Dictionary with data characteristics: 'coefficient_of_variation',
            'zero_ratio', 'seasonality_strength', 'trend_strength', 'mean'
            and 'std' (all plain floats).
        """
        y = np.asarray(df['y'], dtype=float)
        finite_mask = np.isfinite(y)
        observed = y[finite_mask]
        n_observed = observed.size

        if n_observed == 0:
            # Nothing to measure: avoid 0/0 and NaN-propagating statistics.
            return {
                'coefficient_of_variation': 0.0,
                'zero_ratio': 0.0,
                'seasonality_strength': float(self.DEFAULT_SEASONALITY_STRENGTH),
                'trend_strength': 0.0,
                'mean': 0.0,
                'std': 0.0
            }

        mean_y = float(np.mean(observed))
        std_y = float(np.std(observed))

        # Coefficient of variation (only defined for a positive mean)
        cv = std_y / mean_y if mean_y > 0 else 0.0

        # Zero ratio
        zero_ratio = float(np.count_nonzero(observed == 0)) / n_observed

        # Seasonality strength (simple proxy using the spread of a rolling mean)
        # NOTE(review): a centered 7-point rolling mean smooths out a weekly
        # cycle, so this ratio may reflect slow-moving variation rather than
        # weekly seasonality -- confirm this is the intended proxy.
        if n_observed >= self.MIN_POINTS_SEASONALITY:
            if std_y > 0:
                # Roll over the full series so gaps keep their position;
                # windows touching a gap become NaN and are skipped by std().
                smoothed = pd.Series(y).rolling(
                    window=self.SEASONALITY_WINDOW, center=True
                ).mean()
                ratio = smoothed.std() / (std_y + 1e-6)
                seasonality_strength = (
                    ratio if np.isfinite(ratio) else self.DEFAULT_SEASONALITY_STRENGTH
                )
            else:
                seasonality_strength = 0.0
        else:
            seasonality_strength = self.DEFAULT_SEASONALITY_STRENGTH  # Default

        # Trend strength: |r| of a linear fit of the values against position
        trend_strength = 0.0
        if n_observed >= self.MIN_POINTS_TREND:
            from scipy import stats
            positions = np.arange(y.size)[finite_mask]
            r_value = stats.linregress(positions, observed).rvalue
            if np.isfinite(r_value):
                trend_strength = abs(r_value)

        return {
            'coefficient_of_variation': float(cv),
            'zero_ratio': float(zero_ratio),
            'seasonality_strength': float(seasonality_strength),
            'trend_strength': float(trend_strength),
            'mean': mean_y,
            'std': std_y
        }

    @staticmethod
    def _extract_mape(metrics: Optional[Dict[str, float]]) -> float:
        """Return the 'mape' entry as a float, or inf when missing, None or NaN."""
        value = metrics.get('mape') if metrics else None
        if value is None:
            return float('inf')
        value = float(value)
        return float('inf') if np.isnan(value) else value

    def compare_models(
        self,
        prophet_metrics: Dict[str, float],
        hybrid_metrics: Dict[str, float]
    ) -> str:
        """
        Compare Prophet and Hybrid model performance.

        The hybrid is selected only when its MAPE is lower than Prophet's by
        at least HYBRID_IMPROVEMENT_THRESHOLD (relative improvement). A
        missing, None or NaN MAPE is treated as "no usable score": a hybrid
        without a usable score is never selected, and a hybrid with a usable
        score wins when Prophet has none.

        Args:
            prophet_metrics: Prophet model metrics (with 'mape' key)
            hybrid_metrics: Hybrid model metrics (with 'mape' key)

        Returns:
            "prophet" or "hybrid" based on better performance
        """
        prophet_mape = self._extract_mape(prophet_metrics)
        hybrid_mape = self._extract_mape(hybrid_metrics)

        # Calculate improvement
        if not np.isfinite(hybrid_mape):
            # No evidence that the hybrid helps.
            improvement = 0.0
        elif not np.isfinite(prophet_mape):
            # Only the hybrid has a usable score; (inf - x) / inf would be NaN
            # and NaN comparisons are always False, so handle it explicitly.
            improvement = float('inf')
        elif prophet_mape > 0:
            improvement = (prophet_mape - hybrid_mape) / prophet_mape
        else:
            # Prophet is already perfect (or MAPE is degenerate): nothing to gain.
            improvement = 0.0

        # Hybrid must improve by at least threshold to justify complexity
        if improvement >= self.HYBRID_IMPROVEMENT_THRESHOLD:
            logger.info(
                "Hybrid model selected based on performance",
                prophet_mape=prophet_mape,
                hybrid_mape=hybrid_mape,
                improvement=f"{improvement*100:.1f}%"
            )
            return "hybrid"

        logger.info(
            "Prophet model selected (hybrid improvement insufficient)",
            prophet_mape=prophet_mape,
            hybrid_mape=hybrid_mape,
            improvement=f"{improvement*100:.1f}%"
        )
        return "prophet"
|
||
|
|
|
||
|
|
|
||
|
|
def should_use_hybrid_model(
    df: pd.DataFrame,
    product_category: str = "unknown",
    tenant_settings: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Convenience function to determine if hybrid model should be used.

    Builds a ModelSelector, reads the optional 'force_prophet_only' and
    'force_hybrid' entries from tenant_settings, and delegates the decision
    to ModelSelector.select_model_type.

    Args:
        df: Training data
        product_category: Product category
        tenant_settings: Optional tenant-specific settings; None or an empty
            dict means no model is forced

    Returns:
        True if hybrid model should be used, False otherwise
    """
    # Treat a missing settings object like an empty one so both flags
    # fall back to False through the same lookup.
    settings = tenant_settings or {}
    force_prophet = settings.get('force_prophet_only', False)
    force_hybrid = settings.get('force_hybrid', False)

    selected = ModelSelector().select_model_type(
        df=df,
        product_category=product_category,
        force_prophet=force_prophet,
        force_hybrid=force_hybrid
    )

    return selected == "hybrid"
|