Fix deadlock issues in training
@@ -139,70 +139,154 @@ class EnhancedBakeryDataProcessor:
                     tenant_id=tenant_id,
                     job_id=job_id)

-            # Get database session and repositories
-            async with self.database_manager.get_session() as db_session:
-                repos = await self._get_repositories(db_session)
+            # Use provided session if available, otherwise create one
+            if session is None:
+                logger.debug("Creating new session for data preparation",
+                             inventory_product_id=inventory_product_id)
+                async with self.database_manager.get_session() as db_session:
+                    repos = await self._get_repositories(db_session)
+
+                    # Log data preparation start if we have tracking info
+                    if job_id and tenant_id:
+                        logger.debug("About to update training log progress",
+                                     inventory_product_id=inventory_product_id,
+                                     job_id=job_id)
+                        await repos['training_log'].update_log_progress(
+                            job_id, 15, f"preparing_data_{inventory_product_id}", "running"
+                        )
+                        logger.debug("Updated training log progress",
+                                     inventory_product_id=inventory_product_id,
+                                     job_id=job_id)
+                        # Commit the created session
+                        await db_session.commit()
+                        logger.debug("Committed session after data preparation progress update",
+                                     inventory_product_id=inventory_product_id)
+            else:
+                logger.debug("Using provided session for data preparation",
+                             inventory_product_id=inventory_product_id)
+                # Use the provided session
+                repos = await self._get_repositories(session)

                 # Log data preparation start if we have tracking info
                 if job_id and tenant_id:
+                    logger.debug("About to update training log progress with provided session",
+                                 inventory_product_id=inventory_product_id,
+                                 job_id=job_id)
                     await repos['training_log'].update_log_progress(
                         job_id, 15, f"preparing_data_{inventory_product_id}", "running"
                     )
-                    # ✅ FIX: Commit the session to prevent deadlock with parent trainer session
-                    # The trainer has its own session, so we need to commit this update
-                    await db_session.commit()
-                    logger.debug("Committed session after data preparation progress update",
+                    # Don't commit the provided session as the caller manages it
+                    logger.debug("Updated progress with provided session",
                                  inventory_product_id=inventory_product_id)

-            # Step 1: Convert and validate sales data
-            sales_clean = await self._process_sales_data(sales_data, inventory_product_id)
+            logger.debug("Starting Step 1: Convert and validate sales data",
+                         inventory_product_id=inventory_product_id)
+            # Step 1: Convert and validate sales data
+            sales_clean = await self._process_sales_data(sales_data, inventory_product_id)
+            logger.debug("Step 1 completed: Convert and validate sales data",
+                         inventory_product_id=inventory_product_id,
+                         sales_records=len(sales_clean))

-            # FIX: Ensure timezone awareness before any operations
-            sales_clean = self._ensure_timezone_aware(sales_clean)
-            weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
-            traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
+            logger.debug("Starting Step 2: Ensure timezone awareness",
+                         inventory_product_id=inventory_product_id)
+            # FIX: Ensure timezone awareness before any operations
+            sales_clean = self._ensure_timezone_aware(sales_clean)
+            weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
+            traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
+            logger.debug("Step 2 completed: Ensure timezone awareness",
+                         inventory_product_id=inventory_product_id,
+                         weather_records=len(weather_data) if not weather_data.empty else 0,
+                         traffic_records=len(traffic_data) if not traffic_data.empty else 0)

-            # Step 2: Apply date alignment if we have date constraints
-            sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
+            logger.debug("Starting Step 3: Apply date alignment",
+                         inventory_product_id=inventory_product_id)
+            # Step 2: Apply date alignment if we have date constraints
+            sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
+            logger.debug("Step 3 completed: Apply date alignment",
+                         inventory_product_id=inventory_product_id,
+                         sales_records=len(sales_clean))

-            # Step 3: Aggregate to daily level
-            daily_sales = await self._aggregate_daily_sales(sales_clean)
+            logger.debug("Starting Step 4: Aggregate to daily level",
+                         inventory_product_id=inventory_product_id)
+            # Step 3: Aggregate to daily level
+            daily_sales = await self._aggregate_daily_sales(sales_clean)
+            logger.debug("Step 4 completed: Aggregate to daily level",
+                         inventory_product_id=inventory_product_id,
+                         daily_records=len(daily_sales))

-            # Step 4: Add temporal features
-            daily_sales = self._add_temporal_features(daily_sales)
+            logger.debug("Starting Step 5: Add temporal features",
+                         inventory_product_id=inventory_product_id)
+            # Step 4: Add temporal features
+            daily_sales = self._add_temporal_features(daily_sales)
+            logger.debug("Step 5 completed: Add temporal features",
+                         inventory_product_id=inventory_product_id,
+                         features_added=True)

-            # Step 5: Merge external data sources
-            daily_sales = self._merge_weather_features(daily_sales, weather_data)
-            daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
+            logger.debug("Starting Step 6: Merge external data sources",
+                         inventory_product_id=inventory_product_id)
+            # Step 5: Merge external data sources
+            daily_sales = self._merge_weather_features(daily_sales, weather_data)
+            daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
+            logger.debug("Step 6 completed: Merge external data sources",
+                         inventory_product_id=inventory_product_id,
+                         merged_successfully=True)

-            # Step 6: Engineer basic features
-            daily_sales = self._engineer_features(daily_sales)
+            logger.debug("Starting Step 7: Engineer basic features",
+                         inventory_product_id=inventory_product_id)
+            # Step 6: Engineer basic features
+            daily_sales = self._engineer_features(daily_sales)
+            logger.debug("Step 7 completed: Engineer basic features",
+                         inventory_product_id=inventory_product_id,
+                         feature_columns=len([col for col in daily_sales.columns if col not in ['date', 'quantity']]))

-            # Step 6b: Add advanced features (lagged, rolling, cyclical, interactions, trends)
-            daily_sales = self._add_advanced_features(daily_sales)
+            logger.debug("Starting Step 8: Add advanced features",
+                         inventory_product_id=inventory_product_id)
+            # Step 6b: Add advanced features (lagged, rolling, cyclical, interactions, trends)
+            daily_sales = self._add_advanced_features(daily_sales)
+            logger.debug("Step 8 completed: Add advanced features",
+                         inventory_product_id=inventory_product_id,
+                         total_features=len(daily_sales.columns))

-            # Step 7: Handle missing values
-            daily_sales = self._handle_missing_values(daily_sales)
+            logger.debug("Starting Step 9: Handle missing values",
+                         inventory_product_id=inventory_product_id)
+            # Step 7: Handle missing values
+            daily_sales = self._handle_missing_values(daily_sales)
+            logger.debug("Step 9 completed: Handle missing values",
+                         inventory_product_id=inventory_product_id,
+                         missing_values_handled=True)

-            # Step 8: Prepare for Prophet (rename columns and validate)
-            prophet_data = self._prepare_prophet_format(daily_sales)
+            logger.debug("Starting Step 10: Prepare for Prophet format",
+                         inventory_product_id=inventory_product_id)
+            # Step 8: Prepare for Prophet (rename columns and validate)
+            prophet_data = self._prepare_prophet_format(daily_sales)
+            logger.debug("Step 10 completed: Prepare for Prophet format",
+                         inventory_product_id=inventory_product_id,
+                         prophet_records=len(prophet_data))

-            # Step 9: Store processing metadata if we have a tenant
-            if tenant_id:
-                await self._store_processing_metadata(
-                    repos, tenant_id, inventory_product_id, prophet_data, job_id
-                )
+            logger.debug("Starting Step 11: Store processing metadata",
+                         inventory_product_id=inventory_product_id)
+            # Step 9: Store processing metadata if we have a tenant
+            if tenant_id:
+                await self._store_processing_metadata(
+                    repos, tenant_id, inventory_product_id, prophet_data, job_id, session
+                )
+            logger.debug("Step 11 completed: Store processing metadata",
+                         inventory_product_id=inventory_product_id)

             logger.info("Enhanced training data prepared successfully",
                         inventory_product_id=inventory_product_id,
                         data_points=len(prophet_data))

             return prophet_data

         except Exception as e:
             logger.error("Error preparing enhanced training data",
                          inventory_product_id=inventory_product_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
             raise

     async def _store_processing_metadata(self,
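The hunk above reduces to one session-ownership rule: the data processor commits only a session it opened itself, and leaves a caller-provided session for the caller to commit, so the processor and the trainer never hold two competing transactions over the same training-log rows. Below is a minimal, self-contained sketch of that rule; it assumes SQLAlchemy's 2.0 async API with the aiosqlite driver installed, and the names run_with_session and record_progress are illustrative, not part of this commit.

# Illustrative sketch, not part of the commit: commit only the session we own.
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)


async def record_progress(session: AsyncSession) -> None:
    # Stand-in for repos['training_log'].update_log_progress(...)
    await session.execute(text("SELECT 1"))


async def run_with_session(session: AsyncSession | None = None) -> None:
    if session is None:
        # We opened this session, so we are responsible for committing it.
        async with SessionFactory() as own_session:
            await record_progress(own_session)
            await own_session.commit()
    else:
        # The caller owns this session: do the work but never commit here,
        # so only one transaction at a time touches the progress rows.
        await record_progress(session)


if __name__ == "__main__":
    asyncio.run(run_with_session())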
@@ -210,7 +294,8 @@ class EnhancedBakeryDataProcessor:
                                          tenant_id: str,
                                          inventory_product_id: str,
                                          processed_data: pd.DataFrame,
-                                         job_id: str = None):
+                                         job_id: str = None,
+                                         session=None):
         """Store data processing metadata using repository"""
         try:
             # Create processing metadata
@@ -230,9 +315,12 @@ class EnhancedBakeryDataProcessor:
             await repos['training_log'].update_log_progress(
                 job_id, 25, f"data_prepared_{inventory_product_id}", "running"
             )
-            # ✅ FIX: Commit after final progress update to prevent deadlock
-            await db_session.commit()
-            logger.debug("Committed session after data preparation completion",
+            # If we have a session and it's not managed elsewhere, commit it
+            if session is not None:
+                # Don't commit here as the caller will manage the session
+                pass
+
+            logger.debug("Data preparation metadata stored",
                          inventory_product_id=inventory_product_id)

         except Exception as e:
@@ -358,69 +446,160 @@ class EnhancedBakeryDataProcessor:

     async def _process_sales_data(self, sales_data: pd.DataFrame, inventory_product_id: str) -> pd.DataFrame:
         """Process and clean sales data with enhanced validation"""
+        logger.debug("Starting sales data processing",
+                     inventory_product_id=inventory_product_id,
+                     total_records=len(sales_data),
+                     columns=list(sales_data.columns))
         sales_clean = sales_data.copy()

+        logger.debug("Checking for date column existence",
+                     inventory_product_id=inventory_product_id)
         # Ensure date column exists and is datetime
         if 'date' not in sales_clean.columns:
+            logger.error("Sales data must have a 'date' column",
+                         inventory_product_id=inventory_product_id,
+                         available_columns=list(sales_data.columns))
             raise ValueError("Sales data must have a 'date' column")

+        logger.debug("Converting date column to datetime",
+                     inventory_product_id=inventory_product_id)
         sales_clean['date'] = pd.to_datetime(sales_clean['date'])
+        logger.debug("Date conversion completed",
+                     inventory_product_id=inventory_product_id)

         # Handle different quantity column names
         quantity_columns = ['quantity', 'quantity_sold', 'sales', 'units_sold']
+        logger.debug("Looking for quantity column",
+                     inventory_product_id=inventory_product_id,
+                     quantity_columns=quantity_columns)
         quantity_col = None

         for col in quantity_columns:
             if col in sales_clean.columns:
                 quantity_col = col
+                logger.debug("Found quantity column",
+                             inventory_product_id=inventory_product_id,
+                             quantity_column=col)
                 break

         if quantity_col is None:
+            logger.error("Sales data must have one of the expected quantity columns",
+                         inventory_product_id=inventory_product_id,
+                         expected_columns=quantity_columns,
+                         available_columns=list(sales_clean.columns))
             raise ValueError(f"Sales data must have one of these columns: {quantity_columns}")

         # Standardize to 'quantity'
         if quantity_col != 'quantity':
+            logger.debug("Mapping quantity column",
+                         inventory_product_id=inventory_product_id,
+                         from_column=quantity_col,
+                         to_column='quantity')
             sales_clean['quantity'] = sales_clean[quantity_col]
             logger.info("Mapped quantity column",
                         from_column=quantity_col,
                         to_column='quantity')

+        logger.debug("Converting quantity to numeric",
+                     inventory_product_id=inventory_product_id)
         sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
+        logger.debug("Quantity conversion completed",
+                     inventory_product_id=inventory_product_id,
+                     non_numeric_count=sales_clean['quantity'].isna().sum())

         # Remove rows with invalid quantities
+        logger.debug("Removing rows with invalid quantities",
+                     inventory_product_id=inventory_product_id)
         sales_clean = sales_clean.dropna(subset=['quantity'])
+        logger.debug("NaN rows removed",
+                     inventory_product_id=inventory_product_id,
+                     remaining_records=len(sales_clean))

         sales_clean = sales_clean[sales_clean['quantity'] >= 0]  # No negative sales
+        logger.debug("Negative sales removed",
+                     inventory_product_id=inventory_product_id,
+                     remaining_records=len(sales_clean))

         # Filter for the specific product if inventory_product_id column exists
+        logger.debug("Checking for inventory_product_id column",
+                     inventory_product_id=inventory_product_id,
+                     has_inventory_column='inventory_product_id' in sales_clean.columns)
         if 'inventory_product_id' in sales_clean.columns:
+            logger.debug("Filtering for specific product",
+                         inventory_product_id=inventory_product_id,
+                         products_in_data=sales_clean['inventory_product_id'].unique()[:5].tolist())  # Show first 5
+            original_count = len(sales_clean)
             sales_clean = sales_clean[sales_clean['inventory_product_id'] == inventory_product_id]
+            logger.debug("Product filtering completed",
+                         inventory_product_id=inventory_product_id,
+                         original_count=original_count,
+                         filtered_count=len(sales_clean))

         # Remove duplicate dates (keep the one with highest quantity)
+        logger.debug("Removing duplicate dates",
+                     inventory_product_id=inventory_product_id,
+                     before_dedupe=len(sales_clean))
         sales_clean = sales_clean.sort_values(['date', 'quantity'], ascending=[True, False])
         sales_clean = sales_clean.drop_duplicates(subset=['date'], keep='first')
+        logger.debug("Duplicate dates removed",
+                     inventory_product_id=inventory_product_id,
+                     after_dedupe=len(sales_clean))

+        logger.debug("Sales data processing completed",
+                     inventory_product_id=inventory_product_id,
+                     final_records=len(sales_clean))
         return sales_clean

     async def _aggregate_daily_sales(self, sales_data: pd.DataFrame) -> pd.DataFrame:
         """Aggregate sales to daily level with improved date handling"""
+        logger.debug("Starting daily sales aggregation",
+                     input_records=len(sales_data),
+                     columns=list(sales_data.columns))

         if sales_data.empty:
+            logger.debug("Sales data is empty, returning empty DataFrame")
             return pd.DataFrame(columns=['date', 'quantity'])

+        logger.debug("Starting groupby aggregation",
+                     unique_dates=sales_data['date'].nunique(),
+                     date_range=(sales_data['date'].min(), sales_data['date'].max()))

         # Group by date and sum quantities
         daily_sales = sales_data.groupby('date').agg({
             'quantity': 'sum'
         }).reset_index()

+        logger.debug("Groupby aggregation completed",
+                     aggregated_records=len(daily_sales))

         # Ensure we have data for all dates in the range (fill gaps with 0)
+        logger.debug("Creating full date range",
+                     start_date=daily_sales['date'].min(),
+                     end_date=daily_sales['date'].max())
         date_range = pd.date_range(
             start=daily_sales['date'].min(),
             end=daily_sales['date'].max(),
             freq='D'
         )
+        logger.debug("Date range created",
+                     total_dates=len(date_range))

         full_date_df = pd.DataFrame({'date': date_range})
+        logger.debug("Starting merge to fill missing dates",
+                     full_date_records=len(full_date_df),
+                     aggregated_records=len(daily_sales))
         daily_sales = full_date_df.merge(daily_sales, on='date', how='left')
+        logger.debug("Missing date filling merge completed",
+                     final_records=len(daily_sales))

         daily_sales['quantity'] = daily_sales['quantity'].fillna(0)  # Fill missing days with 0 sales
+        logger.debug("NaN filling completed",
+                     remaining_nan_count=daily_sales['quantity'].isna().sum(),
+                     zero_filled_count=(daily_sales['quantity'] == 0).sum())

+        logger.debug("Daily sales aggregation completed",
+                     final_records=len(daily_sales),
+                     final_columns=len(daily_sales.columns))

         return daily_sales

@@ -466,6 +645,10 @@ class EnhancedBakeryDataProcessor:
                                 daily_sales: pd.DataFrame,
                                 weather_data: pd.DataFrame) -> pd.DataFrame:
         """Merge weather features with enhanced Madrid-specific handling"""
+        logger.debug("Starting weather features merge",
+                     daily_sales_records=len(daily_sales),
+                     weather_data_records=len(weather_data) if not weather_data.empty else 0,
+                     weather_columns=list(weather_data.columns) if not weather_data.empty else [])

         # Define weather_defaults OUTSIDE try block to fix scope error
         weather_defaults = {
@@ -477,27 +660,38 @@ class EnhancedBakeryDataProcessor:
         }

         if weather_data.empty:
+            logger.debug("Weather data is empty, adding default columns")
             # Add default weather columns
             for feature, default_value in weather_defaults.items():
                 daily_sales[feature] = default_value
+            logger.debug("Default weather columns added",
+                         features_added=list(weather_defaults.keys()))
             return daily_sales

         try:
             weather_clean = weather_data.copy()
+            logger.debug("Weather data copied",
+                         records=len(weather_clean),
+                         columns=list(weather_clean.columns))

             # Standardize date column
             if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
+                logger.debug("Renaming ds column to date")
                 weather_clean = weather_clean.rename(columns={'ds': 'date'})

             # CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
+            logger.debug("Converting weather data date column to datetime")
             weather_clean['date'] = pd.to_datetime(weather_clean['date'])
+            logger.debug("Converting daily sales date column to datetime")
             daily_sales['date'] = pd.to_datetime(daily_sales['date'])

             # NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
             if weather_clean['date'].dt.tz is not None:
+                logger.debug("Removing timezone from weather data")
                 weather_clean['date'] = weather_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             if daily_sales['date'].dt.tz is not None:
+                logger.debug("Removing timezone from daily sales data")
                 daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             # Map weather columns to standard names
@@ -510,14 +704,24 @@ class EnhancedBakeryDataProcessor:
             }

             weather_features = ['date']
+            logger.debug("Mapping weather columns",
+                         mapping_attempts=list(weather_mapping.keys()))

             for standard_name, possible_names in weather_mapping.items():
                 for possible_name in possible_names:
                     if possible_name in weather_clean.columns:
+                        logger.debug("Processing weather column",
+                                     standard_name=standard_name,
+                                     possible_name=possible_name,
+                                     records=len(weather_clean))

                         # Extract numeric values using robust helper function
                         try:
                             # Check if column contains dict-like objects
+                            logger.debug("Checking for dict objects in weather column")
                             has_dicts = weather_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
+                            logger.debug("Dict object check completed",
+                                         has_dicts=has_dicts)

                             if has_dicts:
                                 logger.warning(f"Weather column {possible_name} contains dict objects, extracting numeric values")
@@ -525,9 +729,14 @@ class EnhancedBakeryDataProcessor:
                                 weather_clean[standard_name] = weather_clean[possible_name].apply(
                                     self._extract_numeric_from_dict
                                 )
+                                logger.debug("Dict extraction completed for weather column",
+                                             extracted_column=standard_name,
+                                             extracted_count=weather_clean[standard_name].notna().sum())
                             else:
                                 # Direct numeric conversion for simple values
+                                logger.debug("Performing direct numeric conversion")
                                 weather_clean[standard_name] = pd.to_numeric(weather_clean[possible_name], errors='coerce')
+                                logger.debug("Direct numeric conversion completed")
                         except Exception as e:
                             logger.warning(f"Error converting weather column {possible_name}: {e}")
                             # Fallback: try to extract from each value
@@ -535,28 +744,55 @@ class EnhancedBakeryDataProcessor:
                                 self._extract_numeric_from_dict
                             )
                         weather_features.append(standard_name)
+                        logger.debug("Added weather feature to list",
+                                     feature=standard_name)
                         break

             # Keep only the features we found
+            logger.debug("Selecting weather features",
+                         selected_features=weather_features)
             weather_clean = weather_clean[weather_features].copy()

             # Merge with sales data
+            logger.debug("Starting merge operation",
+                         daily_sales_rows=len(daily_sales),
+                         weather_rows=len(weather_clean),
+                         date_range_sales=(daily_sales['date'].min(), daily_sales['date'].max()) if len(daily_sales) > 0 else None,
+                         date_range_weather=(weather_clean['date'].min(), weather_clean['date'].max()) if len(weather_clean) > 0 else None)

             merged = daily_sales.merge(weather_clean, on='date', how='left')

+            logger.debug("Merge completed",
+                         merged_rows=len(merged),
+                         merge_type='left')

             # Fill missing weather values with Madrid-appropriate defaults
+            logger.debug("Filling missing weather values",
+                         features_to_fill=list(weather_defaults.keys()))
             for feature, default_value in weather_defaults.items():
                 if feature in merged.columns:
+                    logger.debug("Processing feature for NaN fill",
+                                 feature=feature,
+                                 nan_count=merged[feature].isna().sum())
                     # Ensure the column is numeric before filling
                     merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
                     merged[feature] = merged[feature].fillna(default_value)
+                    logger.debug("NaN fill completed for feature",
+                                 feature=feature,
+                                 final_nan_count=merged[feature].isna().sum())

+            logger.debug("Weather features merge completed",
+                         final_rows=len(merged),
+                         final_columns=len(merged.columns))
             return merged

         except Exception as e:
-            logger.warning("Error merging weather data", error=str(e))
+            logger.warning("Error merging weather data", error=str(e), exc_info=True)
             # Add default weather columns if merge fails
             for feature, default_value in weather_defaults.items():
                 daily_sales[feature] = default_value
+            logger.debug("Default weather columns added after merge failure",
+                         features_added=list(weather_defaults.keys()))
             return daily_sales

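The merge logic above strips timezones from both 'date' columns first because pandas raises an error when a merge key is tz-aware on one side and tz-naive on the other. A small standalone pandas sketch of that normalization, with illustrative data that is not taken from the commit:

# Standalone illustration: normalize tz-aware dates to naive UTC before merging.
import pandas as pd

daily_sales = pd.DataFrame({
    "date": pd.date_range("2024-01-01", periods=3, freq="D"),            # tz-naive
    "quantity": [10, 12, 9],
})
weather = pd.DataFrame({
    "date": pd.date_range("2024-01-01", periods=3, freq="D", tz="UTC"),  # tz-aware
    "temperature_max": [8.0, 9.5, 7.2],
})

# Merging as-is would fail: one key is datetime64[ns, UTC], the other datetime64[ns].
if weather["date"].dt.tz is not None:
    weather["date"] = weather["date"].dt.tz_convert("UTC").dt.tz_localize(None)

merged = daily_sales.merge(weather, on="date", how="left")
print(merged)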
@@ -564,28 +800,43 @@ class EnhancedBakeryDataProcessor:
                                 daily_sales: pd.DataFrame,
                                 traffic_data: pd.DataFrame) -> pd.DataFrame:
         """Merge traffic features with enhanced Madrid-specific handling"""
+        logger.debug("Starting traffic features merge",
+                     daily_sales_records=len(daily_sales),
+                     traffic_data_records=len(traffic_data) if not traffic_data.empty else 0,
+                     traffic_columns=list(traffic_data.columns) if not traffic_data.empty else [])

         if traffic_data.empty:
+            logger.debug("Traffic data is empty, adding default column")
             # Add default traffic column
             daily_sales['traffic_volume'] = 100.0  # Neutral traffic level
+            logger.debug("Default traffic column added",
+                         default_value=100.0)
             return daily_sales

         try:
             traffic_clean = traffic_data.copy()
+            logger.debug("Traffic data copied",
+                         records=len(traffic_clean),
+                         columns=list(traffic_clean.columns))

             # Standardize date column
             if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
+                logger.debug("Renaming ds column to date")
                 traffic_clean = traffic_clean.rename(columns={'ds': 'date'})

             # CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
+            logger.debug("Converting traffic data date column to datetime")
             traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
+            logger.debug("Converting daily sales date column to datetime")
             daily_sales['date'] = pd.to_datetime(daily_sales['date'])

             # NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
             if traffic_clean['date'].dt.tz is not None:
+                logger.debug("Removing timezone from traffic data")
                 traffic_clean['date'] = traffic_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             if daily_sales['date'].dt.tz is not None:
+                logger.debug("Removing timezone from daily sales data")
                 daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             # Map traffic columns to standard names
@@ -597,14 +848,24 @@ class EnhancedBakeryDataProcessor:
             }

             traffic_features = ['date']
+            logger.debug("Mapping traffic columns",
+                         mapping_attempts=list(traffic_mapping.keys()))

             for standard_name, possible_names in traffic_mapping.items():
                 for possible_name in possible_names:
                     if possible_name in traffic_clean.columns:
+                        logger.debug("Processing traffic column",
+                                     standard_name=standard_name,
+                                     possible_name=possible_name,
+                                     records=len(traffic_clean))

                         # Extract numeric values using robust helper function
                         try:
                             # Check if column contains dict-like objects
+                            logger.debug("Checking for dict objects in traffic column")
                             has_dicts = traffic_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
+                            logger.debug("Dict object check completed",
+                                         has_dicts=has_dicts)

                             if has_dicts:
                                 logger.warning(f"Traffic column {possible_name} contains dict objects, extracting numeric values")
@@ -612,9 +873,14 @@ class EnhancedBakeryDataProcessor:
                                 traffic_clean[standard_name] = traffic_clean[possible_name].apply(
                                     self._extract_numeric_from_dict
                                 )
+                                logger.debug("Dict extraction completed for traffic column",
+                                             extracted_column=standard_name,
+                                             extracted_count=traffic_clean[standard_name].notna().sum())
                             else:
                                 # Direct numeric conversion for simple values
+                                logger.debug("Performing direct numeric conversion")
                                 traffic_clean[standard_name] = pd.to_numeric(traffic_clean[possible_name], errors='coerce')
+                                logger.debug("Direct numeric conversion completed")
                         except Exception as e:
                             logger.warning(f"Error converting traffic column {possible_name}: {e}")
                             # Fallback: try to extract from each value
@@ -622,14 +888,28 @@ class EnhancedBakeryDataProcessor:
                                 self._extract_numeric_from_dict
                             )
                         traffic_features.append(standard_name)
+                        logger.debug("Added traffic feature to list",
+                                     feature=standard_name)
                         break

             # Keep only the features we found
+            logger.debug("Selecting traffic features",
+                         selected_features=traffic_features)
             traffic_clean = traffic_clean[traffic_features].copy()

             # Merge with sales data
+            logger.debug("Starting traffic merge operation",
+                         daily_sales_rows=len(daily_sales),
+                         traffic_rows=len(traffic_clean),
+                         date_range_sales=(daily_sales['date'].min(), daily_sales['date'].max()) if len(daily_sales) > 0 else None,
+                         date_range_traffic=(traffic_clean['date'].min(), traffic_clean['date'].max()) if len(traffic_clean) > 0 else None)

             merged = daily_sales.merge(traffic_clean, on='date', how='left')

+            logger.debug("Traffic merge completed",
+                         merged_rows=len(merged),
+                         merge_type='left')

             # Fill missing traffic values with reasonable defaults
             traffic_defaults = {
                 'traffic_volume': 100.0,
@@ -638,18 +918,31 @@ class EnhancedBakeryDataProcessor:
                 'average_speed': 30.0  # km/h typical for Madrid
             }

+            logger.debug("Filling missing traffic values",
+                         features_to_fill=list(traffic_defaults.keys()))
             for feature, default_value in traffic_defaults.items():
                 if feature in merged.columns:
+                    logger.debug("Processing traffic feature for NaN fill",
+                                 feature=feature,
+                                 nan_count=merged[feature].isna().sum())
                     # Ensure the column is numeric before filling
                     merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
                     merged[feature] = merged[feature].fillna(default_value)
+                    logger.debug("NaN fill completed for traffic feature",
+                                 feature=feature,
+                                 final_nan_count=merged[feature].isna().sum())

+            logger.debug("Traffic features merge completed",
+                         final_rows=len(merged),
+                         final_columns=len(merged.columns))
             return merged

         except Exception as e:
-            logger.warning("Error merging traffic data", error=str(e))
+            logger.warning("Error merging traffic data", error=str(e), exc_info=True)
             # Add default traffic column if merge fails
             daily_sales['traffic_volume'] = 100.0
+            logger.debug("Default traffic column added after merge failure",
+                         default_value=100.0)
             return daily_sales

     def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -774,12 +1067,26 @@ class EnhancedBakeryDataProcessor:
         """
         df = df.copy()

-        logger.info("Adding advanced features (lagged, rolling, cyclical, trends)")
+        logger.info("Adding advanced features (lagged, rolling, cyclical, trends)",
+                    input_rows=len(df),
+                    input_columns=len(df.columns))
+
+        # Log column dtypes to identify potential issues
+        logger.debug("Input dataframe dtypes",
+                     dtypes={col: str(dtype) for col, dtype in df.dtypes.items()},
+                     date_column_exists='date' in df.columns)

         # Reset feature engineer to clear previous features
+        logger.debug("Initializing AdvancedFeatureEngineer")
         self.feature_engineer = AdvancedFeatureEngineer()

         # Create all advanced features at once
+        logger.debug("Starting creation of advanced features",
+                     include_lags=True,
+                     include_rolling=True,
+                     include_interactions=True,
+                     include_cyclical=True)

         df = self.feature_engineer.create_all_features(
             df,
             date_column='date',
@@ -789,8 +1096,16 @@ class EnhancedBakeryDataProcessor:
             include_cyclical=True
         )

+        logger.debug("Advanced features creation completed",
+                     output_rows=len(df),
+                     output_columns=len(df.columns))

         # Fill NA values from lagged and rolling features
+        logger.debug("Starting NA value filling",
+                     na_counts={col: df[col].isna().sum() for col in df.columns if df[col].isna().any()})
         df = self.feature_engineer.fill_na_values(df, strategy='forward_backward')
+        logger.debug("NA value filling completed",
+                     remaining_na_counts={col: df[col].isna().sum() for col in df.columns if df[col].isna().any()})

         # Store created feature columns for later reference
         created_features = self.feature_engineer.get_feature_columns()
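Most of the lines added across the data-processor hunks are structured-logging calls that pass context as keyword arguments instead of formatting it into the message. The logger itself is not shown in the diff; assuming a structlog-style logger (one that accepts arbitrary key-value event fields), the pattern looks like the following sketch, whose function name log_step is illustrative only:

# Minimal structured-logging sketch; assumes structlog is installed.
import structlog

logger = structlog.get_logger()


def log_step(step: int, name: str, records: int, inventory_product_id: str) -> None:
    # Key-value fields keep the event name stable and make the context filterable.
    logger.debug("Step completed",
                 step=step,
                 step_name=name,
                 records=records,
                 inventory_product_id=inventory_product_id)


log_step(4, "Aggregate to daily level", records=42, inventory_product_id="prod-123")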
@@ -77,6 +77,7 @@ class EnhancedBakeryMLTrainer:
             tenant_id: Tenant identifier
             training_dataset: Prepared training dataset with aligned dates
             job_id: Training job identifier
+            session: Database session to use (if None, creates one)

         Returns:
             Dictionary with training results for each product
@@ -89,68 +90,20 @@ class EnhancedBakeryMLTrainer:
                     tenant_id=tenant_id)

         try:
-            # Get database session and repositories
-            async with self.database_manager.get_session() as db_session:
-                repos = await self._get_repositories(db_session)
-
-                # Convert sales data to DataFrame
-                sales_df = pd.DataFrame(training_dataset.sales_data)
-                weather_df = pd.DataFrame(training_dataset.weather_data)
-                traffic_df = pd.DataFrame(training_dataset.traffic_data)
-
-                # Validate input data
-                await self._validate_input_data(sales_df, tenant_id)
-
-                # Get unique products from the sales data
-                products = sales_df['inventory_product_id'].unique().tolist()
-
-                # Debug: Log sales data details to understand why only one product is found
-                total_sales_records = len(sales_df)
-                sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
-
-                logger.info("Enhanced training pipeline - Sales data analysis",
-                            total_sales_records=total_sales_records,
-                            products_count=len(products),
-                            products=products,
-                            sales_by_product=sales_by_product)
-
-                if len(products) == 1:
-                    logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
-                                   tenant_id=tenant_id,
-                                   single_product_id=products[0],
-                                   total_sales_records=total_sales_records)
-                elif len(products) == 0:
-                    raise ValueError("No products found in sales data")
-                else:
-                    logger.info("Multiple products detected for training",
-                                products_count=len(products))
-
-                # Event 1: Training Started (0%) - update with actual product count AND time estimates
-                # Calculate accurate time estimates now that we know the actual product count
-                from app.utils.time_estimation import (
-                    calculate_initial_estimate,
-                    calculate_estimated_completion_time,
-                    get_historical_average_estimate
-                )
-
-                # Try to get historical average for more accurate estimates
-                try:
-                    historical_avg = await get_historical_average_estimate(
-                        db_session,
-                        tenant_id
-                    )
-                    avg_time_per_product = historical_avg if historical_avg else 60.0
-                    logger.info("Using historical average for time estimation",
-                                avg_time_per_product=avg_time_per_product,
-                                has_historical_data=historical_avg is not None)
-                except Exception as e:
-                    logger.warning("Could not get historical average, using default",
-                                   error=str(e))
-                    avg_time_per_product = 60.0
-
-                estimated_duration_minutes = calculate_initial_estimate(
-                    total_products=len(products),
-                    avg_training_time_per_product=avg_time_per_product
-                )
-                estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
+            # Use provided session or create new one to prevent nested sessions and deadlocks
+            should_create_session = session is None
+            db_session = session if session is not None else None
+
+            # Use the provided session or create a new one if needed
+            if should_create_session:
+                async with self.database_manager.get_session() as db_session:
+                    return await self._execute_training_pipeline(
+                        tenant_id, training_dataset, job_id, db_session
+                    )
+            else:
+                # Use the provided session (no context manager needed since caller manages it)
+                return await self._execute_training_pipeline(
+                    tenant_id, training_dataset, job_id, session
+                )
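After the rewrite above, train_models_for_tenant is a thin dispatcher: it opens a session only when the caller did not pass one, and both paths hand a single session to one pipeline method. A hedged sketch of that shape, using hypothetical names (Trainer, train, _run_pipeline) rather than the project's real classes:

# Sketch of the create-or-reuse dispatcher shape (hypothetical names).
from typing import Any, Optional

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


class Trainer:
    def __init__(self, session_factory: async_sessionmaker) -> None:
        self._session_factory = session_factory

    async def train(self, tenant_id: str,
                    session: Optional[AsyncSession] = None) -> dict[str, Any]:
        # Open a session only when the caller did not supply one; either way,
        # exactly one session flows through the pipeline, so no nested
        # transactions can end up waiting on each other's locks.
        if session is None:
            async with self._session_factory() as own_session:
                return await self._run_pipeline(tenant_id, own_session)
        return await self._run_pipeline(tenant_id, session)

    async def _run_pipeline(self, tenant_id: str,
                            session: AsyncSession) -> dict[str, Any]:
        # All database access in the pipeline goes through this one session.
        return {"tenant_id": tenant_id, "status": "completed"}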
@@ -177,7 +130,7 @@ class EnhancedBakeryMLTrainer:
             # Process data for each product using enhanced processor
             logger.info("Processing data using enhanced processor")
             processed_data = await self._process_all_products_enhanced(
-                sales_df, weather_df, traffic_df, products, tenant_id, job_id
+                sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
             )

             # Categorize all products for category-specific forecasting
@@ -302,11 +255,219 @@ class EnhancedBakeryMLTrainer:
|
|||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def _execute_training_pipeline(self, tenant_id: str, training_dataset: TrainingDataSet,
|
||||||
|
job_id: str, session) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Execute the training pipeline with the given session.
|
||||||
|
This is extracted to avoid code duplication when handling provided vs. created sessions.
|
||||||
|
"""
|
||||||
|
# Get repositories with the session
|
||||||
|
repos = await self._get_repositories(session)
|
||||||
|
|
||||||
|
# Convert sales data to DataFrame
|
||||||
|
sales_df = pd.DataFrame(training_dataset.sales_data)
|
||||||
|
weather_df = pd.DataFrame(training_dataset.weather_data)
|
||||||
|
traffic_df = pd.DataFrame(training_dataset.traffic_data)
|
||||||
|
|
||||||
|
# Validate input data
|
||||||
|
await self._validate_input_data(sales_df, tenant_id)
|
||||||
|
|
||||||
|
# Get unique products from the sales data
|
||||||
|
products = sales_df['inventory_product_id'].unique().tolist()
|
||||||
|
|
||||||
|
# Debug: Log sales data details to understand why only one product is found
|
||||||
|
total_sales_records = len(sales_df)
|
||||||
|
sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
|
||||||
|
|
||||||
|
logger.info("Enhanced training pipeline - Sales data analysis",
|
||||||
|
total_sales_records=total_sales_records,
|
||||||
|
products_count=len(products),
|
||||||
|
products=products,
|
||||||
|
sales_by_product=sales_by_product)
|
||||||
|
|
||||||
|
if len(products) == 1:
|
||||||
|
logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
single_product_id=products[0],
|
||||||
|
total_sales_records=total_sales_records)
|
||||||
|
elif len(products) == 0:
|
||||||
|
raise ValueError("No products found in sales data")
|
||||||
|
else:
|
||||||
|
logger.info("Multiple products detected for training",
|
||||||
|
products_count=len(products))
|
||||||
|
|
||||||
|
# Event 1: Training Started (0%) - update with actual product count AND time estimates
|
||||||
|
# Calculate accurate time estimates now that we know the actual product count
|
||||||
|
from app.utils.time_estimation import (
|
||||||
|
calculate_initial_estimate,
|
||||||
|
calculate_estimated_completion_time,
|
||||||
|
get_historical_average_estimate
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get historical average for more accurate estimates
|
||||||
|
try:
|
||||||
|
historical_avg = await get_historical_average_estimate(
|
||||||
|
session,
|
||||||
|
tenant_id
|
||||||
|
)
|
||||||
|
avg_time_per_product = historical_avg if historical_avg else 60.0
|
||||||
|
logger.info("Using historical average for time estimation",
|
||||||
|
avg_time_per_product=avg_time_per_product,
|
||||||
|
has_historical_data=historical_avg is not None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not get historical average, using default",
|
||||||
|
error=str(e))
|
||||||
|
avg_time_per_product = 60.0
|
||||||
|
|
||||||
|
estimated_duration_minutes = calculate_initial_estimate(
|
||||||
|
total_products=len(products),
|
||||||
|
avg_training_time_per_product=avg_time_per_product
|
||||||
|
)
|
||||||
|
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
|
||||||
|
|
||||||
|
# Note: Initial event was already published by API endpoint with estimated product count,
|
||||||
|
# this updates with real count and recalculated time estimates based on actual data
|
||||||
|
await publish_training_started(
|
||||||
|
job_id=job_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
total_products=len(products),
|
||||||
|
estimated_duration_minutes=estimated_duration_minutes,
|
||||||
|
estimated_completion_time=estimated_completion_time.isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create initial training log entry
|
||||||
|
await repos['training_log'].update_log_progress(
|
||||||
|
job_id, 5, "data_processing", "running"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✅ FIX: Flush the session to ensure the update is committed before proceeding
|
||||||
|
# This prevents deadlocks when training methods need to acquire locks
|
||||||
|
await session.flush()
|
||||||
|
logger.debug("Flushed session after initial progress update")
|
||||||
|
|
||||||
|
# Process data for each product using enhanced processor
|
||||||
|
logger.info("Processing data using enhanced processor")
|
||||||
|
processed_data = await self._process_all_products_enhanced(
|
||||||
|
sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
|
||||||
|
)
|
||||||
|
|
||||||
|
# Categorize all products for category-specific forecasting
|
||||||
|
logger.info("Categorizing products for optimized forecasting")
|
||||||
|
product_categories = await self._categorize_all_products(
|
||||||
|
sales_df, processed_data
|
||||||
|
)
|
||||||
|
logger.info("Product categorization complete",
|
||||||
|
total_products=len(product_categories),
|
||||||
|
categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
|
||||||
|
for cat in set(product_categories.values())})
|
||||||
|
|
||||||
|
# Event 2: Data Analysis (20%)
|
||||||
|
# Recalculate time remaining based on elapsed time
|
||||||
|
start_time = await repos['training_log'].get_start_time(job_id)
|
||||||
|
elapsed_seconds = 0
|
||||||
|
if start_time:
|
||||||
|
elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
|
||||||
|
|
||||||
|
# Estimate remaining time: we've done ~20% of work (data analysis)
|
||||||
|
# Remaining 80% includes training all products
|
||||||
|
products_to_train = len(processed_data)
|
||||||
|
estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
|
||||||
|
|
||||||
|
# Recalculate estimated completion time
|
||||||
|
estimated_completion_time_data_analysis = calculate_estimated_completion_time(
|
||||||
|
estimated_remaining_seconds / 60
|
||||||
|
)
|
||||||
|
|
||||||
|
await publish_data_analysis(
|
||||||
|
job_id,
|
||||||
|
tenant_id,
|
||||||
|
f"Data analysis completed for {len(processed_data)} products",
|
||||||
|
estimated_time_remaining_seconds=estimated_remaining_seconds,
|
||||||
|
estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")

# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
    job_id=job_id,
    tenant_id=tenant_id,
    total_products=len(processed_data)
)
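ParallelProductProgressTracker's internals are not shown in this diff; a minimal sketch of the aggregation it presumably performs (mapping per-product completions onto the 20-80% slice of overall job progress) could look like:

class ProgressBandTracker:
    # Illustrative sketch only; the real tracker also carries job_id/tenant_id
    # and publishes progress events rather than just returning a percentage.
    def __init__(self, total_products: int, band_start: int = 20, band_end: int = 80):
        self.total = max(total_products, 1)
        self.completed = 0
        self.band_start = band_start
        self.band_end = band_end

    def product_completed(self) -> int:
        self.completed += 1
        fraction = self.completed / self.total
        return self.band_start + int((self.band_end - self.band_start) * fraction)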
# Train all models in parallel (without DB writes to avoid session conflicts)
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_results = await self._train_all_models_enhanced(
    tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, session
)

# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
    tenant_id, job_id, training_results, repos
)
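The split above (train concurrently, persist sequentially) is the core of the deadlock fix. A self-contained sketch of the same pattern, with hypothetical train_one/persist_one callables standing in for the trainer's internals:

import asyncio

async def train_then_persist(product_ids, train_one, persist_one):
    # Run the CPU/IO-heavy training concurrently, with no database access.
    results = await asyncio.gather(*(train_one(pid) for pid in product_ids))
    # Then write results one by one on a single session, so concurrent
    # writers never contend for the same rows or locks.
    for pid, result in zip(product_ids, results):
        await persist_one(pid, result)
    return dict(zip(product_ids, results))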
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
    training_results, repos, tenant_id
)

# Calculate successful and failed trainings
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])

# Event 4: Training Completed (100%)
await publish_training_completed(
    job_id,
    tenant_id,
    successful_trainings,
    failed_trainings,
    total_duration
)

# Create comprehensive result with repository data
result = {
    "job_id": job_id,
    "tenant_id": tenant_id,
    "status": "completed",
    "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
    "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
    "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
    "total_products": len(products),
    "training_results": training_results,
    "enhanced_summary": summary,
    "models_trained": summary.get('models_created', {}),
    "data_info": {
        "date_range": {
            "start": training_dataset.date_range.start.isoformat(),
            "end": training_dataset.date_range.end.isoformat(),
            "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
        },
        "data_sources": [source.value for source in training_dataset.date_range.available_sources],
        "constraints_applied": training_dataset.date_range.constraints
    },
    "repository_metadata": {
        "total_records_created": summary.get('total_db_records', 0),
        "performance_metrics_stored": summary.get('performance_metrics_created', 0),
        "artifacts_created": summary.get('artifacts_created', 0)
    },
    "completed_at": datetime.now().isoformat()
}

logger.info("Enhanced ML training pipeline completed successfully",
            job_id=job_id,
            models_created=len([r for r in training_results.values() if r.get('status') == 'success']))

return result
 async def train_single_product_model(self,
                                      tenant_id: str,
                                      inventory_product_id: str,
                                      training_data: pd.DataFrame,
-                                     job_id: str = None) -> Dict[str, Any]:
+                                     job_id: str = None,
+                                     session=None) -> Dict[str, Any]:
     """
     Train a model for a single product using repository pattern.
@@ -315,6 +476,7 @@ class EnhancedBakeryMLTrainer:
         inventory_product_id: Specific inventory product to train
         training_data: Prepared training DataFrame for the product
         job_id: Training job identifier (optional)
+        session: Database session to use (if None, creates one)

     Returns:
         Dictionary with model training results
@@ -329,9 +491,14 @@ class EnhancedBakeryMLTrainer:
                  data_points=len(training_data))

     try:
-        # Get database session and repositories
-        async with self.database_manager.get_session() as db_session:
-            repos = await self._get_repositories(db_session)
+        # Use provided session or create new one to prevent nested sessions and deadlocks
+        should_create_session = session is None
+        db_session = session if session is not None else None
+
+        if should_create_session:
+            # Only create a session if one wasn't provided
+            async with self.database_manager.get_session() as db_session:
+                repos = await self._get_repositories(db_session)

         # Validate input data
         if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
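The reuse-or-create branch in the hunk above recurs throughout this commit. A compact, self-contained way to express it (an assumed helper, not part of the commit) is an async context manager that yields the caller's session when one is supplied and otherwise opens and manages its own:

from contextlib import asynccontextmanager

@asynccontextmanager
async def session_scope(database_manager, session=None):
    # Hedged sketch: if the caller passed a session, reuse it and leave
    # commit/close to the caller; otherwise open a new session locally.
    if session is not None:
        yield session
    else:
        async with database_manager.get_session() as new_session:
            yield new_session

With such a helper, train_single_product_model could wrap its body in "async with session_scope(self.database_manager, session) as db_session:" instead of branching manually; the explicit branch kept by the commit is equivalent.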
@@ -367,7 +534,8 @@ class EnhancedBakeryMLTrainer:
                 product_data=training_data,
                 job_id=job_id,
                 repos=repos,
-                progress_tracker=progress_tracker
+                progress_tracker=progress_tracker,
+                session=db_session  # Pass the session to prevent nested sessions
             )

             logger.info("Single product training completed",
@@ -391,7 +559,7 @@ class EnhancedBakeryMLTrainer:
                     continue

             # Return appropriate result format
-            return {
+            result_dict = {
                 "job_id": job_id,
                 "tenant_id": tenant_id,
                 "inventory_product_id": inventory_product_id,
@@ -403,6 +571,8 @@ class EnhancedBakeryMLTrainer:
                 "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
             }

+            return result_dict
+
         except Exception as e:
             logger.error("Single product model training failed",
                          job_id=job_id,
@@ -452,7 +622,8 @@ class EnhancedBakeryMLTrainer:
                                       traffic_df: pd.DataFrame,
                                       products: List[str],
                                       tenant_id: str,
-                                      job_id: str) -> Dict[str, pd.DataFrame]:
+                                      job_id: str,
+                                      session=None) -> Dict[str, pd.DataFrame]:
         """Process data for all products using enhanced processor with repository tracking"""
         processed_data = {}

@@ -470,13 +641,15 @@ class EnhancedBakeryMLTrainer:
                 continue

             # Use enhanced data processor with repository tracking
+            # Pass the session to prevent nested session issues
             processed_product_data = await self.enhanced_data_processor.prepare_training_data(
                 sales_data=product_sales,
                 weather_data=weather_df,
                 traffic_data=traffic_df,
                 inventory_product_id=inventory_product_id,
                 tenant_id=tenant_id,
-                job_id=job_id
+                job_id=job_id,
+                session=session  # Pass the session to avoid creating new ones
             )

             processed_data[inventory_product_id] = processed_product_data
@@ -244,7 +244,8 @@ class EnhancedTrainingService:
         training_results = await self.trainer.train_tenant_models(
             tenant_id=tenant_id,
             training_dataset=training_dataset,
-            job_id=job_id
+            job_id=job_id,
+            session=session  # Pass the main session to avoid nested sessions
         )

         await self.training_log_repo.update_log_progress(
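Read together, the EnhancedTrainingService hunk above is the top of the chain: one session opened at the service boundary is threaded through train_tenant_models, _process_all_products_enhanced and prepare_training_data. A hedged sketch of what the surrounding service code presumably looks like (the database_manager handle and function name are assumptions, not shown in the diff):

async def run_tenant_training(trainer, database_manager, tenant_id: str, training_dataset, job_id: str):
    # Sketch only: open a single session once and pass it down, so no nested
    # get_session() call can deadlock against rows this session already holds.
    async with database_manager.get_session() as session:
        return await trainer.train_tenant_models(
            tenant_id=tenant_id,
            training_dataset=training_dataset,
            job_id=job_id,
            session=session,
        )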