Improve AI logic
This commit is contained in:
@@ -25,20 +25,52 @@ class BakeryPredictor:
|
||||
Advanced predictor for bakery demand forecasting with dependency injection
|
||||
Handles Prophet models and business-specific logic
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None):
|
||||
|
||||
def __init__(self, database_manager=None, use_dynamic_rules=True):
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
self.model_cache = {}
|
||||
self.business_rules = BakeryBusinessRules()
|
||||
self.use_dynamic_rules = use_dynamic_rules
|
||||
|
||||
if use_dynamic_rules:
|
||||
from app.ml.dynamic_rules_engine import DynamicRulesEngine
|
||||
from shared.clients.ai_insights_client import AIInsightsClient
|
||||
self.rules_engine = DynamicRulesEngine()
|
||||
self.ai_insights_client = AIInsightsClient(
|
||||
base_url=settings.AI_INSIGHTS_SERVICE_URL or "http://ai-insights-service:8000"
|
||||
)
|
||||
else:
|
||||
self.business_rules = BakeryBusinessRules()
|
||||
|
||||
class BakeryForecaster:
|
||||
"""
|
||||
Enhanced forecaster that integrates with repository pattern
|
||||
Uses enhanced features from training service for predictions
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None):
|
||||
|
||||
def __init__(self, database_manager=None, use_enhanced_features=True):
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
|
||||
self.predictor = BakeryPredictor(database_manager)
|
||||
self.use_enhanced_features = use_enhanced_features
|
||||
|
||||
if use_enhanced_features:
|
||||
# Import enhanced data processor from training service
|
||||
import sys
|
||||
import os
|
||||
# Add training service to path
|
||||
training_path = os.path.join(os.path.dirname(__file__), '../../../training')
|
||||
if training_path not in sys.path:
|
||||
sys.path.insert(0, training_path)
|
||||
|
||||
try:
|
||||
from app.ml.data_processor import EnhancedBakeryDataProcessor
|
||||
self.data_processor = EnhancedBakeryDataProcessor(database_manager)
|
||||
logger.info("Enhanced features enabled for forecasting")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import EnhancedBakeryDataProcessor: {e}, falling back to basic features")
|
||||
self.use_enhanced_features = False
|
||||
self.data_processor = None
|
||||
else:
|
||||
self.data_processor = None
|
||||
|
||||
async def generate_forecast_with_repository(self, tenant_id: str, inventory_product_id: str,
|
||||
forecast_date: date, model_id: str = None) -> Dict[str, Any]:
|
||||
@@ -110,45 +142,87 @@ class BakeryForecaster:
|
||||
logger.error("Error generating base prediction", error=str(e))
|
||||
raise
|
||||
|
||||
def _prepare_prophet_dataframe(self, features: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""Convert features to Prophet-compatible DataFrame"""
|
||||
|
||||
async def _prepare_prophet_dataframe(self, features: Dict[str, Any],
|
||||
historical_data: pd.DataFrame = None) -> pd.DataFrame:
|
||||
"""
|
||||
Convert features to Prophet-compatible DataFrame.
|
||||
Uses enhanced features when available (60+ features vs basic 10).
|
||||
"""
|
||||
|
||||
try:
|
||||
# Create base DataFrame
|
||||
df = pd.DataFrame({
|
||||
'ds': [pd.to_datetime(features['date'])]
|
||||
})
|
||||
|
||||
# Add regressor features
|
||||
feature_mapping = {
|
||||
'temperature': 'temperature',
|
||||
'precipitation': 'precipitation',
|
||||
'humidity': 'humidity',
|
||||
'wind_speed': 'wind_speed',
|
||||
'traffic_volume': 'traffic_volume',
|
||||
'pedestrian_count': 'pedestrian_count'
|
||||
}
|
||||
|
||||
for feature_key, df_column in feature_mapping.items():
|
||||
if feature_key in features and features[feature_key] is not None:
|
||||
df[df_column] = float(features[feature_key])
|
||||
else:
|
||||
df[df_column] = 0.0
|
||||
|
||||
# Add categorical features
|
||||
df['day_of_week'] = int(features.get('day_of_week', 0))
|
||||
if self.use_enhanced_features and self.data_processor:
|
||||
# Use enhanced data processor from training service
|
||||
logger.info("Generating enhanced features for prediction")
|
||||
|
||||
# Create future date range
|
||||
future_dates = pd.DatetimeIndex([pd.to_datetime(features['date'])])
|
||||
|
||||
# Prepare weather forecast DataFrame
|
||||
weather_df = pd.DataFrame({
|
||||
'date': [pd.to_datetime(features['date'])],
|
||||
'temperature': [features.get('temperature', 15.0)],
|
||||
'precipitation': [features.get('precipitation', 0.0)],
|
||||
'humidity': [features.get('humidity', 60.0)],
|
||||
'wind_speed': [features.get('wind_speed', 5.0)],
|
||||
'pressure': [features.get('pressure', 1013.0)]
|
||||
})
|
||||
|
||||
# Use data processor to create ALL enhanced features
|
||||
df = await self.data_processor.prepare_prediction_features(
|
||||
future_dates=future_dates,
|
||||
weather_forecast=weather_df,
|
||||
traffic_forecast=None, # Will add when traffic forecasting is implemented
|
||||
historical_data=historical_data # For lagged features
|
||||
)
|
||||
|
||||
logger.info(f"Generated {len(df.columns)} enhanced features for prediction")
|
||||
return df
|
||||
|
||||
else:
|
||||
# Fallback to basic features
|
||||
logger.info("Using basic features for prediction")
|
||||
|
||||
# Create base DataFrame
|
||||
df = pd.DataFrame({
|
||||
'ds': [pd.to_datetime(features['date'])]
|
||||
})
|
||||
|
||||
# Add regressor features
|
||||
feature_mapping = {
|
||||
'temperature': 'temperature',
|
||||
'precipitation': 'precipitation',
|
||||
'humidity': 'humidity',
|
||||
'wind_speed': 'wind_speed',
|
||||
'traffic_volume': 'traffic_volume',
|
||||
'pedestrian_count': 'pedestrian_count'
|
||||
}
|
||||
|
||||
for feature_key, df_column in feature_mapping.items():
|
||||
if feature_key in features and features[feature_key] is not None:
|
||||
df[df_column] = float(features[feature_key])
|
||||
else:
|
||||
df[df_column] = 0.0
|
||||
|
||||
# Add categorical features
|
||||
df['day_of_week'] = int(features.get('day_of_week', 0))
|
||||
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||
|
||||
# Business type
|
||||
business_type = features.get('business_type', 'individual')
|
||||
df['is_central_workshop'] = int(business_type == 'central_workshop')
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing Prophet dataframe: {e}, falling back to basic features")
|
||||
# Fallback to basic implementation on error
|
||||
df = pd.DataFrame({'ds': [pd.to_datetime(features['date'])]})
|
||||
df['temperature'] = features.get('temperature', 15.0)
|
||||
df['precipitation'] = features.get('precipitation', 0.0)
|
||||
df['is_weekend'] = int(features.get('is_weekend', False))
|
||||
df['is_holiday'] = int(features.get('is_holiday', False))
|
||||
|
||||
# Business type
|
||||
business_type = features.get('business_type', 'individual')
|
||||
df['is_central_workshop'] = int(business_type == 'central_workshop')
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error preparing Prophet dataframe", error=str(e))
|
||||
raise
|
||||
|
||||
def _add_uncertainty_bands(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any]) -> Dict[str, float]:
|
||||
@@ -225,80 +299,256 @@ class BakeryForecaster:
|
||||
|
||||
def _calculate_weekend_uncertainty(self, features: Dict[str, Any]) -> float:
|
||||
"""Calculate weekend-based uncertainty"""
|
||||
|
||||
|
||||
if features.get('is_weekend', False):
|
||||
return 0.1 # 10% additional uncertainty on weekends
|
||||
return 0.0
|
||||
|
||||
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
|
||||
"""
|
||||
Fetch learned dynamic rules from AI Insights Service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
inventory_product_id: Product UUID
|
||||
rule_type: Type of rules (weather, temporal, holiday, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary of learned rules with factors
|
||||
"""
|
||||
try:
|
||||
from uuid import UUID
|
||||
|
||||
# Fetch latest rules insight for this product
|
||||
insights = await self.ai_insights_client.get_insights(
|
||||
tenant_id=UUID(tenant_id),
|
||||
filters={
|
||||
'category': 'forecasting',
|
||||
'actionable_only': False,
|
||||
'page_size': 100
|
||||
}
|
||||
)
|
||||
|
||||
if not insights or 'items' not in insights:
|
||||
return {}
|
||||
|
||||
# Find the most recent rules insight for this product
|
||||
for insight in insights['items']:
|
||||
if insight.get('source_model') == 'dynamic_rules_engine':
|
||||
metrics = insight.get('metrics_json', {})
|
||||
if metrics.get('inventory_product_id') == inventory_product_id:
|
||||
rules_data = metrics.get('rules', {})
|
||||
return rules_data.get(rule_type, {})
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch dynamic rules: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
class BakeryBusinessRules:
|
||||
"""
|
||||
Business rules for Spanish bakeries
|
||||
Applies domain-specific adjustments to predictions
|
||||
Supports both dynamic learned rules and hardcoded fallbacks
|
||||
"""
|
||||
|
||||
def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
|
||||
business_type: str) -> Dict[str, float]:
|
||||
"""Apply all business rules to prediction"""
|
||||
|
||||
|
||||
def __init__(self, use_dynamic_rules=False, ai_insights_client=None):
|
||||
self.use_dynamic_rules = use_dynamic_rules
|
||||
self.ai_insights_client = ai_insights_client
|
||||
self.rules_cache = {}
|
||||
|
||||
async def apply_rules(self, prediction: Dict[str, float], features: Dict[str, Any],
|
||||
business_type: str, tenant_id: str = None, inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply all business rules to prediction (dynamic or hardcoded)"""
|
||||
|
||||
adjusted_prediction = prediction.copy()
|
||||
|
||||
|
||||
# Apply weather rules
|
||||
adjusted_prediction = self._apply_weather_rules(adjusted_prediction, features)
|
||||
|
||||
adjusted_prediction = await self._apply_weather_rules(
|
||||
adjusted_prediction, features, tenant_id, inventory_product_id
|
||||
)
|
||||
|
||||
# Apply time-based rules
|
||||
adjusted_prediction = self._apply_time_rules(adjusted_prediction, features)
|
||||
|
||||
adjusted_prediction = await self._apply_time_rules(
|
||||
adjusted_prediction, features, tenant_id, inventory_product_id
|
||||
)
|
||||
|
||||
# Apply business type rules
|
||||
adjusted_prediction = self._apply_business_type_rules(adjusted_prediction, business_type)
|
||||
|
||||
|
||||
# Apply Spanish-specific rules
|
||||
adjusted_prediction = self._apply_spanish_rules(adjusted_prediction, features)
|
||||
|
||||
|
||||
return adjusted_prediction
|
||||
|
||||
def _apply_weather_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Apply weather-based business rules"""
|
||||
|
||||
# Rain reduces foot traffic
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
rain_factor = settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat"] *= rain_factor
|
||||
prediction["yhat_lower"] *= rain_factor
|
||||
prediction["yhat_upper"] *= rain_factor
|
||||
|
||||
# Extreme temperatures affect different products differently
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
# Hot weather reduces bread sales, increases cold drinks
|
||||
prediction["yhat"] *= 0.9
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
# Cold weather increases hot beverage sales
|
||||
prediction["yhat"] *= 1.1
|
||||
|
||||
|
||||
async def _get_dynamic_rules(self, tenant_id: str, inventory_product_id: str, rule_type: str) -> Dict[str, float]:
|
||||
"""
|
||||
Fetch learned dynamic rules from AI Insights Service.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
inventory_product_id: Product UUID
|
||||
rule_type: Type of rules (weather, temporal, holiday, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary of learned rules with factors
|
||||
"""
|
||||
# Check cache first
|
||||
cache_key = f"{tenant_id}:{inventory_product_id}:{rule_type}"
|
||||
if cache_key in self.rules_cache:
|
||||
return self.rules_cache[cache_key]
|
||||
|
||||
try:
|
||||
from uuid import UUID
|
||||
|
||||
if not self.ai_insights_client:
|
||||
return {}
|
||||
|
||||
# Fetch latest rules insight for this product
|
||||
insights = await self.ai_insights_client.get_insights(
|
||||
tenant_id=UUID(tenant_id),
|
||||
filters={
|
||||
'category': 'forecasting',
|
||||
'actionable_only': False,
|
||||
'page_size': 100
|
||||
}
|
||||
)
|
||||
|
||||
if not insights or 'items' not in insights:
|
||||
return {}
|
||||
|
||||
# Find the most recent rules insight for this product
|
||||
for insight in insights['items']:
|
||||
if insight.get('source_model') == 'dynamic_rules_engine':
|
||||
metrics = insight.get('metrics_json', {})
|
||||
if metrics.get('inventory_product_id') == inventory_product_id:
|
||||
rules_data = metrics.get('rules', {})
|
||||
result = rules_data.get(rule_type, {})
|
||||
# Cache the result
|
||||
self.rules_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch dynamic rules: {e}")
|
||||
return {}
|
||||
|
||||
async def _apply_weather_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str = None,
|
||||
inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply weather-based business rules (dynamic or hardcoded fallback)"""
|
||||
|
||||
if self.use_dynamic_rules and tenant_id and inventory_product_id:
|
||||
try:
|
||||
# Fetch dynamic weather rules
|
||||
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'weather')
|
||||
|
||||
# Apply learned rain impact
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
rain_factor = rules.get('rain_factor', settings.RAIN_IMPACT_FACTOR)
|
||||
prediction["yhat"] *= rain_factor
|
||||
prediction["yhat_lower"] *= rain_factor
|
||||
prediction["yhat_upper"] *= rain_factor
|
||||
|
||||
# Apply learned temperature impact
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
hot_factor = rules.get('temperature_hot_factor', 0.9)
|
||||
prediction["yhat"] *= hot_factor
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
cold_factor = rules.get('temperature_cold_factor', 1.1)
|
||||
prediction["yhat"] *= cold_factor
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply dynamic weather rules, using fallback: {e}")
|
||||
# Fallback to hardcoded
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
prediction["yhat"] *= settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.RAIN_IMPACT_FACTOR
|
||||
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
prediction["yhat"] *= 0.9
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
prediction["yhat"] *= 1.1
|
||||
else:
|
||||
# Use hardcoded rules
|
||||
precipitation = features.get('precipitation', 0)
|
||||
if precipitation > 0:
|
||||
rain_factor = settings.RAIN_IMPACT_FACTOR
|
||||
prediction["yhat"] *= rain_factor
|
||||
prediction["yhat_lower"] *= rain_factor
|
||||
prediction["yhat_upper"] *= rain_factor
|
||||
|
||||
temperature = features.get('temperature')
|
||||
if temperature is not None:
|
||||
if temperature > settings.TEMPERATURE_THRESHOLD_HOT:
|
||||
prediction["yhat"] *= 0.9
|
||||
elif temperature < settings.TEMPERATURE_THRESHOLD_COLD:
|
||||
prediction["yhat"] *= 1.1
|
||||
|
||||
return prediction
|
||||
|
||||
def _apply_time_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any]) -> Dict[str, float]:
|
||||
"""Apply time-based business rules"""
|
||||
|
||||
# Weekend adjustment
|
||||
if features.get('is_weekend', False):
|
||||
weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= weekend_factor
|
||||
prediction["yhat_lower"] *= weekend_factor
|
||||
prediction["yhat_upper"] *= weekend_factor
|
||||
|
||||
# Holiday adjustment
|
||||
if features.get('is_holiday', False):
|
||||
holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= holiday_factor
|
||||
prediction["yhat_lower"] *= holiday_factor
|
||||
prediction["yhat_upper"] *= holiday_factor
|
||||
|
||||
async def _apply_time_rules(self, prediction: Dict[str, float],
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str = None,
|
||||
inventory_product_id: str = None) -> Dict[str, float]:
|
||||
"""Apply time-based business rules (dynamic or hardcoded fallback)"""
|
||||
|
||||
if self.use_dynamic_rules and tenant_id and inventory_product_id:
|
||||
try:
|
||||
# Fetch dynamic temporal rules
|
||||
rules = await self._get_dynamic_rules(tenant_id, inventory_product_id, 'temporal')
|
||||
|
||||
# Apply learned weekend adjustment
|
||||
if features.get('is_weekend', False):
|
||||
weekend_factor = rules.get('weekend_factor', settings.WEEKEND_ADJUSTMENT_FACTOR)
|
||||
prediction["yhat"] *= weekend_factor
|
||||
prediction["yhat_lower"] *= weekend_factor
|
||||
prediction["yhat_upper"] *= weekend_factor
|
||||
|
||||
# Apply learned holiday adjustment
|
||||
if features.get('is_holiday', False):
|
||||
holiday_factor = rules.get('holiday_factor', settings.HOLIDAY_ADJUSTMENT_FACTOR)
|
||||
prediction["yhat"] *= holiday_factor
|
||||
prediction["yhat_lower"] *= holiday_factor
|
||||
prediction["yhat_upper"] *= holiday_factor
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply dynamic time rules, using fallback: {e}")
|
||||
# Fallback to hardcoded
|
||||
if features.get('is_weekend', False):
|
||||
prediction["yhat"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
|
||||
if features.get('is_holiday', False):
|
||||
prediction["yhat"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_lower"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat_upper"] *= settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
else:
|
||||
# Use hardcoded rules
|
||||
if features.get('is_weekend', False):
|
||||
weekend_factor = settings.WEEKEND_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= weekend_factor
|
||||
prediction["yhat_lower"] *= weekend_factor
|
||||
prediction["yhat_upper"] *= weekend_factor
|
||||
|
||||
if features.get('is_holiday', False):
|
||||
holiday_factor = settings.HOLIDAY_ADJUSTMENT_FACTOR
|
||||
prediction["yhat"] *= holiday_factor
|
||||
prediction["yhat_lower"] *= holiday_factor
|
||||
prediction["yhat_upper"] *= holiday_factor
|
||||
|
||||
return prediction
|
||||
|
||||
def _apply_business_type_rules(self, prediction: Dict[str, float],
|
||||
|
||||
Reference in New Issue
Block a user