Improve AI data processing: robust numeric extraction from API dicts, advanced feature engineering for predictions, and holidays-library based Spanish holiday detection
This commit is contained in:
@@ -17,6 +17,8 @@ from shared.database.base import create_database_manager
|
||||
from shared.database.transactions import transactional
|
||||
from shared.database.exceptions import DatabaseError
|
||||
from app.core.config import settings
|
||||
from app.ml.enhanced_features import AdvancedFeatureEngineer
|
||||
import holidays
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -26,16 +28,67 @@ class EnhancedBakeryDataProcessor:
|
||||
Integrates date alignment, data cleaning, feature engineering, and preparation for ML models.
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager=None, region: str = 'MD'):
    """
    Initialize the enhanced bakery data processor.

    Args:
        database_manager: Optional pre-built database manager. When None, a
            new one is created from the configured DATABASE_URL.
        region: Spanish province code for regional holiday calendars
            (e.g. 'MD' = Madrid, 'PV' = Basque Country).
    """
    self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
    self.scalers = {}  # Per-feature scalers, reused at prediction time
    self.imputers = {}  # Per-feature imputers for missing value handling
    self.date_alignment_service = DateAlignmentService()
    self.feature_engineer = AdvancedFeatureEngineer()
    self.region = region  # Region for holidays (MD=Madrid, PV=Basque, etc.)
    self.spain_holidays = holidays.Spain(prov=region)  # Initialize holidays library
|
||||
|
||||
def get_scalers(self) -> Dict[str, Any]:
|
||||
"""Return the scalers/normalization parameters for use during prediction"""
|
||||
return self.scalers.copy()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_numeric_from_dict(value: Any) -> Optional[float]:
|
||||
"""
|
||||
Robust extraction of numeric values from complex data structures.
|
||||
Handles various dict structures that might come from external APIs.
|
||||
|
||||
Args:
|
||||
value: Any value that might be a dict, numeric, or other type
|
||||
|
||||
Returns:
|
||||
Numeric value as float, or None if extraction fails
|
||||
"""
|
||||
# If already numeric, return it
|
||||
if isinstance(value, (int, float)) and not isinstance(value, bool):
|
||||
return float(value)
|
||||
|
||||
# If it's a dict, try multiple extraction strategies
|
||||
if isinstance(value, dict):
|
||||
# Strategy 1: Try common keys
|
||||
for key in ['value', 'data', 'result', 'amount', 'count', 'number', 'val']:
|
||||
if key in value:
|
||||
extracted = value[key]
|
||||
# Recursively extract if nested
|
||||
if isinstance(extracted, dict):
|
||||
return EnhancedBakeryDataProcessor._extract_numeric_from_dict(extracted)
|
||||
elif isinstance(extracted, (int, float)) and not isinstance(extracted, bool):
|
||||
return float(extracted)
|
||||
|
||||
# Strategy 2: Try to find first numeric value in dict
|
||||
for v in value.values():
|
||||
if isinstance(v, (int, float)) and not isinstance(v, bool):
|
||||
return float(v)
|
||||
elif isinstance(v, dict):
|
||||
# Recursively try nested dicts
|
||||
result = EnhancedBakeryDataProcessor._extract_numeric_from_dict(v)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Strategy 3: Try to convert string to numeric
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# If all strategies fail, return None (will be converted to NaN)
|
||||
return None
|
||||
|
||||
async def _get_repositories(self, session):
|
||||
"""Initialize repositories with session"""
|
||||
return {
|
||||
@@ -117,9 +170,12 @@ class EnhancedBakeryDataProcessor:
|
||||
daily_sales = self._merge_weather_features(daily_sales, weather_data)
|
||||
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
|
||||
|
||||
# Step 6: Engineer additional features
|
||||
# Step 6: Engineer basic features
|
||||
daily_sales = self._engineer_features(daily_sales)
|
||||
|
||||
|
||||
# Step 6b: Add advanced features (lagged, rolling, cyclical, interactions, trends)
|
||||
daily_sales = self._add_advanced_features(daily_sales)
|
||||
|
||||
# Step 7: Handle missing values
|
||||
daily_sales = self._handle_missing_values(daily_sales)
|
||||
|
||||
@@ -177,52 +233,73 @@ class EnhancedBakeryDataProcessor:
|
||||
async def prepare_prediction_features(self,
|
||||
future_dates: pd.DatetimeIndex,
|
||||
weather_forecast: pd.DataFrame = None,
|
||||
traffic_forecast: pd.DataFrame = None) -> pd.DataFrame:
|
||||
traffic_forecast: pd.DataFrame = None,
|
||||
historical_data: pd.DataFrame = None) -> pd.DataFrame:
|
||||
"""
|
||||
Create features for future predictions with proper date handling.
|
||||
|
||||
|
||||
Args:
|
||||
future_dates: Future dates to predict
|
||||
weather_forecast: Weather forecast data
|
||||
traffic_forecast: Traffic forecast data
|
||||
|
||||
historical_data: Historical data for creating lagged and rolling features
|
||||
|
||||
Returns:
|
||||
DataFrame with features for prediction
|
||||
"""
|
||||
try:
|
||||
# Create base future dataframe
|
||||
future_df = pd.DataFrame({'ds': future_dates})
|
||||
|
||||
|
||||
# Add temporal features
|
||||
future_df = self._add_temporal_features(
|
||||
future_df.rename(columns={'ds': 'date'})
|
||||
).rename(columns={'date': 'ds'})
|
||||
|
||||
|
||||
# Add weather features
|
||||
if weather_forecast is not None and not weather_forecast.empty:
|
||||
weather_features = weather_forecast.copy()
|
||||
if 'date' in weather_features.columns:
|
||||
weather_features = weather_features.rename(columns={'date': 'ds'})
|
||||
|
||||
|
||||
future_df = future_df.merge(weather_features, on='ds', how='left')
|
||||
|
||||
# Add traffic features
|
||||
|
||||
# Add traffic features
|
||||
if traffic_forecast is not None and not traffic_forecast.empty:
|
||||
traffic_features = traffic_forecast.copy()
|
||||
if 'date' in traffic_features.columns:
|
||||
traffic_features = traffic_features.rename(columns={'date': 'ds'})
|
||||
|
||||
|
||||
future_df = future_df.merge(traffic_features, on='ds', how='left')
|
||||
|
||||
# Engineer additional features
|
||||
|
||||
# Engineer basic features
|
||||
future_df = self._engineer_features(future_df.rename(columns={'ds': 'date'}))
|
||||
|
||||
# Add advanced features if historical data is provided
|
||||
if historical_data is not None and not historical_data.empty:
|
||||
# Combine historical and future data to calculate lagged/rolling features
|
||||
combined_df = pd.concat([
|
||||
historical_data.rename(columns={'ds': 'date'}),
|
||||
future_df
|
||||
], ignore_index=True).sort_values('date')
|
||||
|
||||
# Apply advanced features to combined data
|
||||
combined_df = self._add_advanced_features(combined_df)
|
||||
|
||||
# Extract only the future rows
|
||||
future_df = combined_df[combined_df['date'].isin(future_df['date'])].copy()
|
||||
else:
|
||||
# Without historical data, add advanced features with NaN for lags
|
||||
logger.warning("No historical data provided, lagged features will be NaN")
|
||||
future_df = self._add_advanced_features(future_df)
|
||||
|
||||
future_df = future_df.rename(columns={'date': 'ds'})
|
||||
|
||||
|
||||
# Handle missing values in future data
|
||||
future_df = self._handle_missing_values_future(future_df)
|
||||
|
||||
|
||||
return future_df
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating prediction features", error=str(e))
|
||||
# Return minimal features if error
|
||||
@@ -428,19 +505,40 @@ class EnhancedBakeryDataProcessor:
|
||||
for standard_name, possible_names in weather_mapping.items():
|
||||
for possible_name in possible_names:
|
||||
if possible_name in weather_clean.columns:
|
||||
weather_clean[standard_name] = pd.to_numeric(weather_clean[possible_name], errors='coerce')
|
||||
# Extract numeric values using robust helper function
|
||||
try:
|
||||
# Check if column contains dict-like objects
|
||||
has_dicts = weather_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
|
||||
|
||||
if has_dicts:
|
||||
logger.warning(f"Weather column {possible_name} contains dict objects, extracting numeric values")
|
||||
# Use robust extraction for all values
|
||||
weather_clean[standard_name] = weather_clean[possible_name].apply(
|
||||
self._extract_numeric_from_dict
|
||||
)
|
||||
else:
|
||||
# Direct numeric conversion for simple values
|
||||
weather_clean[standard_name] = pd.to_numeric(weather_clean[possible_name], errors='coerce')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error converting weather column {possible_name}: {e}")
|
||||
# Fallback: try to extract from each value
|
||||
weather_clean[standard_name] = weather_clean[possible_name].apply(
|
||||
self._extract_numeric_from_dict
|
||||
)
|
||||
weather_features.append(standard_name)
|
||||
break
|
||||
|
||||
|
||||
# Keep only the features we found
|
||||
weather_clean = weather_clean[weather_features].copy()
|
||||
|
||||
|
||||
# Merge with sales data
|
||||
merged = daily_sales.merge(weather_clean, on='date', how='left')
|
||||
|
||||
|
||||
# Fill missing weather values with Madrid-appropriate defaults
|
||||
for feature, default_value in weather_defaults.items():
|
||||
if feature in merged.columns:
|
||||
# Ensure the column is numeric before filling
|
||||
merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
|
||||
merged[feature] = merged[feature].fillna(default_value)
|
||||
|
||||
return merged
|
||||
@@ -494,16 +592,35 @@ class EnhancedBakeryDataProcessor:
|
||||
for standard_name, possible_names in traffic_mapping.items():
|
||||
for possible_name in possible_names:
|
||||
if possible_name in traffic_clean.columns:
|
||||
traffic_clean[standard_name] = pd.to_numeric(traffic_clean[possible_name], errors='coerce')
|
||||
# Extract numeric values using robust helper function
|
||||
try:
|
||||
# Check if column contains dict-like objects
|
||||
has_dicts = traffic_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
|
||||
|
||||
if has_dicts:
|
||||
logger.warning(f"Traffic column {possible_name} contains dict objects, extracting numeric values")
|
||||
# Use robust extraction for all values
|
||||
traffic_clean[standard_name] = traffic_clean[possible_name].apply(
|
||||
self._extract_numeric_from_dict
|
||||
)
|
||||
else:
|
||||
# Direct numeric conversion for simple values
|
||||
traffic_clean[standard_name] = pd.to_numeric(traffic_clean[possible_name], errors='coerce')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error converting traffic column {possible_name}: {e}")
|
||||
# Fallback: try to extract from each value
|
||||
traffic_clean[standard_name] = traffic_clean[possible_name].apply(
|
||||
self._extract_numeric_from_dict
|
||||
)
|
||||
traffic_features.append(standard_name)
|
||||
break
|
||||
|
||||
|
||||
# Keep only the features we found
|
||||
traffic_clean = traffic_clean[traffic_features].copy()
|
||||
|
||||
|
||||
# Merge with sales data
|
||||
merged = daily_sales.merge(traffic_clean, on='date', how='left')
|
||||
|
||||
|
||||
# Fill missing traffic values with reasonable defaults
|
||||
traffic_defaults = {
|
||||
'traffic_volume': 100.0,
|
||||
@@ -511,9 +628,11 @@ class EnhancedBakeryDataProcessor:
|
||||
'congestion_level': 1.0, # Low congestion
|
||||
'average_speed': 30.0 # km/h typical for Madrid
|
||||
}
|
||||
|
||||
|
||||
for feature, default_value in traffic_defaults.items():
|
||||
if feature in merged.columns:
|
||||
# Ensure the column is numeric before filling
|
||||
merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
|
||||
merged[feature] = merged[feature].fillna(default_value)
|
||||
|
||||
return merged
|
||||
@@ -530,17 +649,23 @@ class EnhancedBakeryDataProcessor:
|
||||
|
||||
# Weather-based features
|
||||
if 'temperature' in df.columns:
|
||||
# Ensure temperature is numeric (defensive check)
|
||||
df['temperature'] = pd.to_numeric(df['temperature'], errors='coerce').fillna(15.0)
|
||||
|
||||
df['temp_squared'] = df['temperature'] ** 2
|
||||
df['is_hot_day'] = (df['temperature'] > 25).astype(int) # Hot days in Madrid
|
||||
df['is_cold_day'] = (df['temperature'] < 10).astype(int) # Cold days
|
||||
df['is_pleasant_day'] = ((df['temperature'] >= 18) & (df['temperature'] <= 25)).astype(int)
|
||||
|
||||
|
||||
# Temperature categories for bakery products
|
||||
df['temp_category'] = pd.cut(df['temperature'],
|
||||
bins=[-np.inf, 5, 15, 25, np.inf],
|
||||
df['temp_category'] = pd.cut(df['temperature'],
|
||||
bins=[-np.inf, 5, 15, 25, np.inf],
|
||||
labels=[0, 1, 2, 3]).astype(int)
|
||||
|
||||
|
||||
if 'precipitation' in df.columns:
|
||||
# Ensure precipitation is numeric (defensive check)
|
||||
df['precipitation'] = pd.to_numeric(df['precipitation'], errors='coerce').fillna(0.0)
|
||||
|
||||
df['is_rainy_day'] = (df['precipitation'] > 0.1).astype(int)
|
||||
df['is_heavy_rain'] = (df['precipitation'] > 10).astype(int)
|
||||
df['rain_intensity'] = pd.cut(df['precipitation'],
|
||||
@@ -549,10 +674,13 @@ class EnhancedBakeryDataProcessor:
|
||||
|
||||
# Traffic-based features with NaN protection
|
||||
if 'traffic_volume' in df.columns:
|
||||
# Ensure traffic_volume is numeric (defensive check)
|
||||
df['traffic_volume'] = pd.to_numeric(df['traffic_volume'], errors='coerce').fillna(100.0)
|
||||
|
||||
# Calculate traffic quantiles for relative measures
|
||||
q75 = df['traffic_volume'].quantile(0.75)
|
||||
q25 = df['traffic_volume'].quantile(0.25)
|
||||
|
||||
|
||||
df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
|
||||
df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)
|
||||
|
||||
@@ -578,7 +706,15 @@ class EnhancedBakeryDataProcessor:
|
||||
|
||||
# Fill any remaining NaN values
|
||||
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)
|
||||
|
||||
|
||||
# Ensure other weather features are numeric if they exist
|
||||
for weather_col in ['humidity', 'wind_speed', 'pressure', 'pedestrian_count', 'congestion_level', 'average_speed']:
|
||||
if weather_col in df.columns:
|
||||
df[weather_col] = pd.to_numeric(df[weather_col], errors='coerce').fillna(
|
||||
{'humidity': 60.0, 'wind_speed': 5.0, 'pressure': 1013.0,
|
||||
'pedestrian_count': 50.0, 'congestion_level': 1.0, 'average_speed': 30.0}.get(weather_col, 0.0)
|
||||
)
|
||||
|
||||
# Interaction features - bakery specific
|
||||
if 'is_weekend' in df.columns and 'temperature' in df.columns:
|
||||
df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']
|
||||
@@ -619,7 +755,39 @@ class EnhancedBakeryDataProcessor:
|
||||
column=col,
|
||||
nan_count=nan_count)
|
||||
df[col] = df[col].fillna(0.0)
|
||||
|
||||
|
||||
return df
|
||||
|
||||
def _add_advanced_features(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Enrich the dataframe with advanced ML features via AdvancedFeatureEngineer:
    lagged features, rolling statistics, cyclical encodings, interactions
    and trend features.
    """
    enriched = df.copy()

    logger.info("Adding advanced features (lagged, rolling, cyclical, trends)")

    # Fresh engineer per call so feature bookkeeping from a previous run
    # does not leak into this one.
    engineer = AdvancedFeatureEngineer()
    self.feature_engineer = engineer

    enriched = engineer.create_all_features(
        enriched,
        date_column='date',
        include_lags=True,
        include_rolling=True,
        include_interactions=True,
        include_cyclical=True
    )

    # Lag/rolling windows leave leading NaNs; forward/backward fill them.
    enriched = engineer.fill_na_values(enriched, strategy='forward_backward')

    # Record what was created, logging only a prefix to keep logs short.
    new_columns = engineer.get_feature_columns()
    logger.info(f"Added {len(new_columns)} advanced features",
                features=new_columns[:10])

    return enriched
|
||||
|
||||
def _handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -733,46 +901,83 @@ class EnhancedBakeryDataProcessor:
|
||||
return 4 # Autumn
|
||||
|
||||
def _is_spanish_holiday(self, date: datetime) -> bool:
|
||||
"""Check if a date is a major Spanish holiday"""
|
||||
month_day = (date.month, date.day)
|
||||
|
||||
# Major Spanish holidays that affect bakery sales
|
||||
spanish_holidays = [
|
||||
(1, 1), # New Year
|
||||
(1, 6), # Epiphany (Reyes)
|
||||
(5, 1), # Labour Day
|
||||
(8, 15), # Assumption
|
||||
(10, 12), # National Day
|
||||
(11, 1), # All Saints
|
||||
(12, 6), # Constitution
|
||||
(12, 8), # Immaculate Conception
|
||||
(12, 25), # Christmas
|
||||
(5, 15), # San Isidro (Madrid patron saint)
|
||||
(5, 2), # Madrid Community Day
|
||||
]
|
||||
|
||||
return month_day in spanish_holidays
|
||||
"""
|
||||
Check if a date is a Spanish holiday using holidays library.
|
||||
Supports dynamic Easter calculation and regional holidays.
|
||||
"""
|
||||
try:
|
||||
# Convert to date if datetime
|
||||
if isinstance(date, datetime):
|
||||
date = date.date()
|
||||
elif isinstance(date, pd.Timestamp):
|
||||
date = date.date()
|
||||
|
||||
# Check if date is in holidays
|
||||
return date in self.spain_holidays
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking holiday status for {date}: {e}")
|
||||
# Fallback to checking basic holidays
|
||||
month_day = (date.month, date.day)
|
||||
basic_holidays = [
|
||||
(1, 1), (1, 6), (5, 1), (8, 15), (10, 12),
|
||||
(11, 1), (12, 6), (12, 8), (12, 25)
|
||||
]
|
||||
return month_day in basic_holidays
|
||||
|
||||
def _is_school_holiday(self, date: datetime) -> bool:
|
||||
"""Check if a date is during school holidays (approximate)"""
|
||||
month = date.month
|
||||
|
||||
# Approximate Spanish school holiday periods
|
||||
# Summer holidays (July-August)
|
||||
if month in [7, 8]:
|
||||
return True
|
||||
|
||||
# Christmas holidays (mid December to early January)
|
||||
if month == 12 and date.day >= 20:
|
||||
return True
|
||||
if month == 1 and date.day <= 10:
|
||||
return True
|
||||
|
||||
# Easter holidays (approximate - early April)
|
||||
if month == 4 and date.day <= 15:
|
||||
return True
|
||||
|
||||
return False
|
||||
"""
|
||||
Check if a date is during school holidays in Spain.
|
||||
Uses dynamic Easter calculation and standard Spanish school calendar.
|
||||
"""
|
||||
try:
|
||||
from datetime import timedelta
|
||||
import holidays as hol
|
||||
|
||||
# Convert to date if datetime
|
||||
if isinstance(date, datetime):
|
||||
check_date = date.date()
|
||||
elif isinstance(date, pd.Timestamp):
|
||||
check_date = date.date()
|
||||
else:
|
||||
check_date = date
|
||||
|
||||
month = check_date.month
|
||||
day = check_date.day
|
||||
|
||||
# Summer holidays (July 1 - August 31)
|
||||
if month in [7, 8]:
|
||||
return True
|
||||
|
||||
# Christmas holidays (December 23 - January 7)
|
||||
if (month == 12 and day >= 23) or (month == 1 and day <= 7):
|
||||
return True
|
||||
|
||||
# Easter/Spring break (Semana Santa)
|
||||
# Calculate Easter for this year
|
||||
year = check_date.year
|
||||
spain_hol = hol.Spain(years=year, prov=self.region)
|
||||
|
||||
# Find Easter dates (Viernes Santo - Good Friday, and nearby days)
|
||||
# Easter break typically spans 1 week before and after Easter Sunday
|
||||
for holiday_date, holiday_name in spain_hol.items():
|
||||
if 'viernes santo' in holiday_name.lower() or 'easter' in holiday_name.lower():
|
||||
# Easter break: 7 days before and 7 days after
|
||||
easter_start = holiday_date - timedelta(days=7)
|
||||
easter_end = holiday_date + timedelta(days=7)
|
||||
if easter_start <= check_date <= easter_end:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking school holiday for {date}: {e}")
|
||||
# Fallback to simple approximation
|
||||
month = date.month if hasattr(date, 'month') else date.month
|
||||
day = date.day if hasattr(date, 'day') else date.day
|
||||
return (month in [7, 8] or
|
||||
(month == 12 and day >= 23) or
|
||||
(month == 1 and day <= 7) or
|
||||
(month == 4 and 1 <= day <= 15)) # Approximate Easter
|
||||
|
||||
async def calculate_feature_importance(self,
|
||||
model_data: pd.DataFrame,
|
||||
|
||||
Reference in New Issue
Block a user