Improve training code 2

This commit is contained in:
Urtzi Alfaro
2025-07-28 20:20:54 +02:00
parent 98f546af12
commit 7cd595df81
6 changed files with 229 additions and 153 deletions

View File

@@ -7,7 +7,7 @@ Handles data preparation, date alignment, cleaning, and feature engineering for
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import logging
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
@@ -278,16 +278,23 @@ class BakeryDataProcessor:
return df
def _merge_weather_features(self,
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with enhanced handling"""
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with enhanced Madrid-specific handling"""
# ✅ FIX: Define weather_defaults OUTSIDE try block to fix scope error
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'pressure': 1013.0
}
if weather_data.empty:
# Add default weather columns with Madrid-appropriate values
daily_sales['temperature'] = 15.0 # Average Madrid temperature
daily_sales['precipitation'] = 0.0 # Default no rain
daily_sales['humidity'] = 60.0 # Moderate humidity
daily_sales['wind_speed'] = 5.0 # Light wind
# Add default weather columns
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
try:
@@ -297,14 +304,22 @@ class BakeryDataProcessor:
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
weather_clean = weather_clean.rename(columns={'ds': 'date'})
# ✅ FIX: Ensure timezone consistency
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# Remove timezone info from both to make them compatible
if weather_clean['date'].dt.tz is not None:
weather_clean['date'] = weather_clean['date'].dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
daily_sales['date'] = daily_sales['date'].dt.tz_localize(None)
# Map weather columns to standard names
weather_mapping = {
'temperature': ['temperature', 'temp', 'temperatura', 'temp_avg', 'temperature_avg'],
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion', 'rainfall'],
'temperature': ['temperature', 'temp', 'temperatura'],
'precipitation': ['precipitation', 'precip', 'rain', 'lluvia'],
'humidity': ['humidity', 'humedad', 'relative_humidity'],
'wind_speed': ['wind_speed', 'viento', 'wind', 'wind_avg'],
'wind_speed': ['wind_speed', 'viento', 'wind'],
'pressure': ['pressure', 'presion', 'atmospheric_pressure']
}
@@ -324,14 +339,6 @@ class BakeryDataProcessor:
merged = daily_sales.merge(weather_clean, on='date', how='left')
# Fill missing weather values with Madrid-appropriate defaults
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'pressure': 1013.0
}
for feature, default_value in weather_defaults.items():
if feature in merged.columns:
merged[feature] = merged[feature].fillna(default_value)
@@ -340,10 +347,11 @@ class BakeryDataProcessor:
except Exception as e:
logger.warning(f"Error merging weather data: {e}")
# Add default weather columns if merge fails
# Add default weather columns if merge fails (weather_defaults now in scope)
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
def _merge_traffic_features(self,
daily_sales: pd.DataFrame,
@@ -420,8 +428,8 @@ class BakeryDataProcessor:
# Temperature categories for bakery products
df['temp_category'] = pd.cut(df['temperature'],
bins=[-np.inf, 5, 15, 25, np.inf],
labels=[0, 1, 2, 3]).astype(int)
bins=[-np.inf, 5, 15, 25, np.inf],
labels=[0, 1, 2, 3]).astype(int)
if 'precipitation' in df.columns:
df['is_rainy_day'] = (df['precipitation'] > 0.1).astype(int)
@@ -430,7 +438,7 @@ class BakeryDataProcessor:
bins=[-0.1, 0, 2, 10, np.inf],
labels=[0, 1, 2, 3]).astype(int)
# Traffic-based features
# ✅ FIX: Traffic-based features with NaN protection
if 'traffic_volume' in df.columns:
# Calculate traffic quantiles for relative measures
q75 = df['traffic_volume'].quantile(0.75)
@@ -438,7 +446,21 @@ class BakeryDataProcessor:
df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)
df['traffic_normalized'] = (df['traffic_volume'] - df['traffic_volume'].mean()) / df['traffic_volume'].std()
# ✅ FIX: Safe normalization with NaN protection
traffic_std = df['traffic_volume'].std()
traffic_mean = df['traffic_volume'].mean()
if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
# Normal case: valid standard deviation
df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
else:
# Edge case: all values are the same or contain NaN
logger.warning("Traffic volume has zero standard deviation or contains NaN, using zeros for normalized values")
df['traffic_normalized'] = 0.0
# ✅ ADDITIONAL SAFETY: Fill any remaining NaN values
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)
# Interaction features - bakery specific
if 'is_weekend' in df.columns and 'temperature' in df.columns:
@@ -465,30 +487,20 @@ class BakeryDataProcessor:
# Month-specific features for bakery seasonality
if 'month' in df.columns:
# Tourist season in Madrid (spring/summer)
df['is_tourist_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# High-demand months (holidays, summer)
df['is_high_demand_month'] = df['month'].isin([6, 7, 8, 12]).astype(int)
# Christmas season (affects bakery sales significantly)
df['is_christmas_season'] = df['month'].isin([11, 12]).astype(int)
# Back-to-school/work season
df['is_back_to_work_season'] = df['month'].isin([1, 9]).astype(int)
# Spring/summer months
df['is_warm_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# Lagged features (if we have enough data)
if len(df) > 7 and 'quantity' in df.columns:
# Rolling averages for trend detection
df['sales_7day_avg'] = df['quantity'].rolling(window=7, min_periods=3).mean()
df['sales_14day_avg'] = df['quantity'].rolling(window=14, min_periods=7).mean()
# Day-over-day changes
df['sales_change_1day'] = df['quantity'].diff()
df['sales_change_7day'] = df['quantity'].diff(7) # Week-over-week
# Fill NaN values for lagged features
df['sales_7day_avg'] = df['sales_7day_avg'].fillna(df['quantity'])
df['sales_14day_avg'] = df['sales_14day_avg'].fillna(df['quantity'])
df['sales_change_1day'] = df['sales_change_1day'].fillna(0)
df['sales_change_7day'] = df['sales_change_7day'].fillna(0)
# ✅ FINAL SAFETY CHECK: Remove any remaining NaN values
# Check for NaN values in all numeric columns and fill them
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if df[col].isna().any():
nan_count = df[col].isna().sum()
logger.warning(f"Found {nan_count} NaN values in column '{col}', filling with 0")
df[col] = df[col].fillna(0.0)
return df