Files
bakery-ia/services/training/app/ml/model_selector.py

258 lines
8.5 KiB
Python

"""
Model Selection System
Determines whether to use Prophet-only or Hybrid Prophet+XGBoost models
"""
import pandas as pd
import numpy as np
from typing import Dict, Any, Optional
import structlog
# Module-level structlog logger; ModelSelector emits its selection
# diagnostics (chosen model, scores, data characteristics) through it.
logger = structlog.get_logger()
class ModelSelector:
    """
    Intelligent model selection based on data characteristics.

    Decision Criteria:
    - Data size: Hybrid needs more data (min 90 data points)
    - Complexity: High variance (coefficient of variation) benefits from XGBoost
    - Seasonality strength: Weak seasonality benefits from XGBoost
    - Sparsity: A high share of zero values nudges towards hybrid
    - Product category: some categories lean towards one model or the other
    - Historical performance: compare_models() compares validation MAPE

    All thresholds are class attributes so they can be tuned (or overridden
    in a subclass) without touching the decision logic.
    """

    # Thresholds for model selection
    MIN_DATA_POINTS_HYBRID = 90  # Minimum data points for hybrid
    HIGH_VARIANCE_THRESHOLD = 0.5  # CV > 0.5 suggests complex patterns
    LOW_SEASONALITY_THRESHOLD = 0.3  # Weak seasonal patterns
    HYBRID_IMPROVEMENT_THRESHOLD = 0.05  # 5% MAPE improvement to justify hybrid

    # Data-size scoring: more than LARGE favours hybrid, fewer than SMALL
    # favours Prophet; anything in between is neutral.
    LARGE_DATASET_THRESHOLD = 180
    SMALL_DATASET_THRESHOLD = 120

    # Share of zero values above which the series is treated as sparse.
    HIGH_ZERO_RATIO_THRESHOLD = 0.3

    # Category hints: event-driven products vs. stable products.
    HYBRID_FAVORED_CATEGORIES = ('seasonal', 'cakes')
    PROPHET_FAVORED_CATEGORIES = ('bread', 'savory')

    # Seasonality estimation: (lag, weight) pairs for the autocorrelation
    # blend (weekly weighted more for bakery data), the minimum series length
    # to attempt it, and the neutral value used when it cannot be measured.
    SEASONALITY_LAG_WEIGHTS = ((7, 0.7), (30, 0.3))
    MIN_POINTS_FOR_SEASONALITY = 14
    DEFAULT_SEASONALITY_STRENGTH = 0.5

    # Minimum series length before a linear trend fit is attempted.
    MIN_POINTS_FOR_TREND = 30

    def select_model_type(
        self,
        df: pd.DataFrame,
        product_category: str = "unknown",
        force_prophet: bool = False,
        force_hybrid: bool = False
    ) -> str:
        """
        Select best model type based on data characteristics.

        Forced selections win (force_prophet takes precedence over
        force_hybrid). Below MIN_DATA_POINTS_HYBRID rows Prophet is always
        chosen. Otherwise each criterion adds points to one side, and hybrid
        is chosen only if it strictly outscores Prophet (ties go to Prophet).

        Args:
            df: Training data with 'y' column
            product_category: Product category (bread, pastries, etc.)
            force_prophet: Force Prophet-only model
            force_hybrid: Force hybrid model

        Returns:
            "prophet" or "hybrid"
        """
        # Honor forced selections
        if force_prophet:
            logger.info("Prophet-only model forced by configuration")
            return "prophet"
        if force_hybrid:
            logger.info("Hybrid model forced by configuration")
            return "hybrid"

        n_points = len(df)

        # Check minimum data requirements
        if n_points < self.MIN_DATA_POINTS_HYBRID:
            logger.info(
                "Insufficient data for hybrid model, using Prophet",
                data_points=n_points,
                min_required=self.MIN_DATA_POINTS_HYBRID
            )
            return "prophet"

        characteristics = self._analyze_data_characteristics(df)
        cv = characteristics['coefficient_of_variation']
        seasonality = characteristics['seasonality_strength']

        score_hybrid = 0
        score_prophet = 0

        # Factor 1: Data complexity (variance)
        if cv > self.HIGH_VARIANCE_THRESHOLD:
            score_hybrid += 2
            logger.debug("High variance detected, favoring hybrid", cv=cv)
        else:
            score_prophet += 1

        # Factor 2: Seasonality strength
        if seasonality < self.LOW_SEASONALITY_THRESHOLD:
            score_hybrid += 2
            logger.debug("Weak seasonality detected, favoring hybrid", strength=seasonality)
        else:
            score_prophet += 1

        # Factor 3: Data size (more data = better for hybrid)
        if n_points > self.LARGE_DATASET_THRESHOLD:
            score_hybrid += 1
        elif n_points < self.SMALL_DATASET_THRESHOLD:
            score_prophet += 1

        # Factor 4: Product category considerations
        if product_category in self.HYBRID_FAVORED_CATEGORIES:
            # Event-driven products benefit from XGBoost pattern learning
            score_hybrid += 1
        elif product_category in self.PROPHET_FAVORED_CATEGORIES:
            # Stable products work well with Prophet
            score_prophet += 1

        # Factor 5: Zero ratio (sparse data) - difficult series, hybrid might help
        if characteristics['zero_ratio'] > self.HIGH_ZERO_RATIO_THRESHOLD:
            score_hybrid += 1

        # Strict comparison: ties fall back to the simpler Prophet model
        selected_model = "hybrid" if score_hybrid > score_prophet else "prophet"

        logger.info(
            "Model selection complete",
            selected_model=selected_model,
            score_hybrid=score_hybrid,
            score_prophet=score_prophet,
            data_points=n_points,
            cv=cv,
            seasonality=seasonality,
            category=product_category
        )
        return selected_model

    def _analyze_data_characteristics(self, df: pd.DataFrame) -> Dict[str, float]:
        """
        Analyze time series characteristics.

        Args:
            df: DataFrame with 'y' column (sales data)

        Returns:
            Dictionary with keys 'coefficient_of_variation', 'zero_ratio',
            'seasonality_strength', 'trend_strength', 'mean' and 'std'.
            An empty series yields zeros (and the neutral default seasonality)
            instead of NaN.
        """
        y = df['y'].values
        n_points = len(y)

        if n_points == 0:
            return {
                'coefficient_of_variation': 0.0,
                'zero_ratio': 0.0,
                'seasonality_strength': float(self.DEFAULT_SEASONALITY_STRENGTH),
                'trend_strength': 0.0,
                'mean': 0.0,
                'std': 0.0
            }

        mean = np.mean(y)
        std = np.std(y)

        # Coefficient of variation (undefined for non-positive means -> 0)
        cv = std / mean if mean > 0 else 0

        zero_ratio = (y == 0).sum() / n_points

        return {
            'coefficient_of_variation': float(cv),
            'zero_ratio': float(zero_ratio),
            'seasonality_strength': float(self._seasonality_strength(y)),
            'trend_strength': float(self._trend_strength(y)),
            'mean': float(mean),
            'std': float(std)
        }

    def _seasonality_strength(self, y: np.ndarray) -> float:
        """
        Blend absolute autocorrelations at the configured lags into [0, 1].

        Lags longer than the series are skipped (they contribute 0). A lag
        whose autocorrelation is NaN (e.g. a lagged window with zero variance,
        as for constant sales) also contributes nothing. Previously that NaN
        propagated through min/max and was clamped to 1.0 ("maximal
        seasonality"). If no computed lag is defined, or autocorrelation
        raises, the neutral DEFAULT_SEASONALITY_STRENGTH is returned.
        """
        if len(y) < self.MIN_POINTS_FOR_SEASONALITY:
            return self.DEFAULT_SEASONALITY_STRENGTH

        series = pd.Series(y)
        strength = 0.0
        any_defined = False
        try:
            for lag, weight in self.SEASONALITY_LAG_WEIGHTS:
                if len(y) <= lag:
                    continue  # not enough history for this lag
                corr = series.autocorr(lag=lag)
                if np.isnan(corr):
                    continue  # undefined correlation: no evidence either way
                any_defined = True
                strength += abs(corr) * weight
        except Exception:
            # Best-effort metric: fall back to the neutral default
            return self.DEFAULT_SEASONALITY_STRENGTH

        if not any_defined:
            return self.DEFAULT_SEASONALITY_STRENGTH

        # Ensure in valid range [0, 1]
        return max(0.0, min(1.0, strength))

    def _trend_strength(self, y: np.ndarray) -> float:
        """Return |r| of a linear fit of y against its index (0 for short series)."""
        if len(y) < self.MIN_POINTS_FOR_TREND:
            return 0.0
        from scipy import stats
        _, _, r_value, _, _ = stats.linregress(np.arange(len(y)), y)
        return abs(r_value)

    @staticmethod
    def _coerce_mape(value: Any) -> float:
        """
        Convert a metrics 'mape' entry to float.

        Missing (None), non-numeric or NaN values map to +inf, i.e. the worst
        possible score.
        """
        try:
            mape = float(value)
        except (TypeError, ValueError):
            return float('inf')
        return float('inf') if np.isnan(mape) else mape

    @staticmethod
    def _hybrid_improvement(prophet_mape: float, hybrid_mape: float) -> float:
        """
        Relative MAPE reduction of hybrid over Prophet (positive = hybrid better).

        If Prophet has no usable (finite) score, a finite hybrid score counts
        as an unbounded improvement (+inf). If neither is finite, the result
        is 0. A non-positive Prophet MAPE (already perfect) yields 0.
        """
        if not np.isfinite(prophet_mape):
            return float('inf') if np.isfinite(hybrid_mape) else 0.0
        if prophet_mape <= 0:
            return 0.0
        return (prophet_mape - hybrid_mape) / prophet_mape

    def compare_models(
        self,
        prophet_metrics: Dict[str, float],
        hybrid_metrics: Dict[str, float]
    ) -> str:
        """
        Compare Prophet and Hybrid model performance.

        Hybrid is chosen only when it reduces MAPE by at least
        HYBRID_IMPROVEMENT_THRESHOLD (relative), to justify its extra
        complexity. A missing/None/NaN MAPE is treated as the worst score.

        Args:
            prophet_metrics: Prophet model metrics (with 'mape' key)
            hybrid_metrics: Hybrid model metrics (with 'mape' key)

        Returns:
            "prophet" or "hybrid" based on better performance
        """
        prophet_mape = self._coerce_mape(prophet_metrics.get('mape'))
        hybrid_mape = self._coerce_mape(hybrid_metrics.get('mape'))
        improvement = self._hybrid_improvement(prophet_mape, hybrid_mape)

        # Hybrid must improve by at least threshold to justify complexity
        if improvement >= self.HYBRID_IMPROVEMENT_THRESHOLD:
            logger.info(
                "Hybrid model selected based on performance",
                prophet_mape=prophet_mape,
                hybrid_mape=hybrid_mape,
                improvement=f"{improvement*100:.1f}%"
            )
            return "hybrid"

        logger.info(
            "Prophet model selected (hybrid improvement insufficient)",
            prophet_mape=prophet_mape,
            hybrid_mape=hybrid_mape,
            improvement=f"{improvement*100:.1f}%"
        )
        return "prophet"
def should_use_hybrid_model(
    df: pd.DataFrame,
    product_category: str = "unknown",
    tenant_settings: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Convenience function to determine if hybrid model should be used.

    Args:
        df: Training data (with the 'y' column ModelSelector.select_model_type reads)
        product_category: Product category
        tenant_settings: Optional tenant-specific settings. The keys read are
            'force_prophet_only' and 'force_hybrid'; both default to False
            when absent or when no settings are given.

    Returns:
        True if hybrid model should be used, False otherwise
    """
    # None (or an empty mapping) means "no overrides"
    settings = tenant_settings or {}

    selected = ModelSelector().select_model_type(
        df=df,
        product_category=product_category,
        force_prophet=settings.get('force_prophet_only', False),
        force_hybrid=settings.get('force_hybrid', False)
    )
    return selected == "hybrid"