Fix deadlock issues in training

Urtzi Alfaro
2025-11-05 18:47:20 +01:00
parent fd0a96e254
commit 74215d3e85
3 changed files with 620 additions and 131 deletions
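The change applied in all three files follows one pattern: each async method now accepts an optional session argument, opens and commits its own session only when none is provided, and never commits or closes a session owned by the caller. A minimal sketch of that pattern, using a hypothetical SessionAwareWorker rather than the project's actual classes:

    class SessionAwareWorker:
        """Illustrative only: the 'reuse the caller's session or create your own' pattern."""

        def __init__(self, database_manager):
            # database_manager is assumed to expose an async get_session() context manager
            self.database_manager = database_manager

        async def do_work(self, payload, session=None):
            if session is None:
                # We own this session: open it, use it, commit it ourselves.
                async with self.database_manager.get_session() as own_session:
                    result = await self._do_work_impl(payload, own_session)
                    await own_session.commit()
                    return result
            # The caller owns the session: reuse it and never commit or close it here,
            # otherwise two nested sessions can end up waiting on each other's row locks.
            return await self._do_work_impl(payload, session)

        async def _do_work_impl(self, payload, session):
            ...  # queries and updates issued against the given session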


@@ -139,70 +139,154 @@ class EnhancedBakeryDataProcessor:
tenant_id=tenant_id,
job_id=job_id)
# Use provided session if available, otherwise create one
if session is None:
logger.debug("Creating new session for data preparation",
inventory_product_id=inventory_product_id)
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Log data preparation start if we have tracking info
if job_id and tenant_id:
logger.debug("About to update training log progress",
inventory_product_id=inventory_product_id,
job_id=job_id)
await repos['training_log'].update_log_progress(
job_id, 15, f"preparing_data_{inventory_product_id}", "running"
)
logger.debug("Updated training log progress",
inventory_product_id=inventory_product_id,
job_id=job_id)
# Commit the created session
await db_session.commit()
logger.debug("Committed session after data preparation progress update",
inventory_product_id=inventory_product_id)
else:
logger.debug("Using provided session for data preparation",
inventory_product_id=inventory_product_id)
# Use the provided session
repos = await self._get_repositories(session)
# Log data preparation start if we have tracking info
if job_id and tenant_id:
logger.debug("About to update training log progress with provided session",
inventory_product_id=inventory_product_id,
job_id=job_id)
await repos['training_log'].update_log_progress(
job_id, 15, f"preparing_data_{inventory_product_id}", "running"
)
logger.debug("Updated training log progress with provided session",
inventory_product_id=inventory_product_id,
job_id=job_id)
# Don't commit the provided session as the caller manages it
logger.debug("Updated progress with provided session",
inventory_product_id=inventory_product_id)
logger.debug("Starting Step 1: Convert and validate sales data",
inventory_product_id=inventory_product_id)
# Step 1: Convert and validate sales data
sales_clean = await self._process_sales_data(sales_data, inventory_product_id)
logger.debug("Step 1 completed: Convert and validate sales data",
inventory_product_id=inventory_product_id,
sales_records=len(sales_clean))
logger.debug("Starting Step 2: Ensure timezone awareness",
inventory_product_id=inventory_product_id)
# FIX: Ensure timezone awareness before any operations
sales_clean = self._ensure_timezone_aware(sales_clean)
weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
logger.debug("Step 2 completed: Ensure timezone awareness",
inventory_product_id=inventory_product_id,
weather_records=len(weather_data) if not weather_data.empty else 0,
traffic_records=len(traffic_data) if not traffic_data.empty else 0)
logger.debug("Starting Step 3: Apply date alignment",
inventory_product_id=inventory_product_id)
# Step 2: Apply date alignment if we have date constraints
sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
logger.debug("Step 3 completed: Apply date alignment",
inventory_product_id=inventory_product_id,
sales_records=len(sales_clean))
logger.debug("Starting Step 4: Aggregate to daily level",
inventory_product_id=inventory_product_id)
# Step 3: Aggregate to daily level
daily_sales = await self._aggregate_daily_sales(sales_clean)
logger.debug("Step 4 completed: Aggregate to daily level",
inventory_product_id=inventory_product_id,
daily_records=len(daily_sales))
logger.debug("Starting Step 5: Add temporal features",
inventory_product_id=inventory_product_id)
# Step 4: Add temporal features
daily_sales = self._add_temporal_features(daily_sales)
logger.debug("Step 5 completed: Add temporal features",
inventory_product_id=inventory_product_id,
features_added=True)
logger.debug("Starting Step 6: Merge external data sources",
inventory_product_id=inventory_product_id)
# Step 5: Merge external data sources
daily_sales = self._merge_weather_features(daily_sales, weather_data)
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
logger.debug("Step 6 completed: Merge external data sources",
inventory_product_id=inventory_product_id,
merged_successfully=True)
logger.debug("Starting Step 7: Engineer basic features",
inventory_product_id=inventory_product_id)
# Step 6: Engineer basic features
daily_sales = self._engineer_features(daily_sales)
logger.debug("Step 7 completed: Engineer basic features",
inventory_product_id=inventory_product_id,
feature_columns=len([col for col in daily_sales.columns if col not in ['date', 'quantity']]))
logger.debug("Starting Step 8: Add advanced features",
inventory_product_id=inventory_product_id)
# Step 6b: Add advanced features (lagged, rolling, cyclical, interactions, trends)
daily_sales = self._add_advanced_features(daily_sales)
logger.debug("Step 8 completed: Add advanced features",
inventory_product_id=inventory_product_id,
total_features=len(daily_sales.columns))
logger.debug("Starting Step 9: Handle missing values",
inventory_product_id=inventory_product_id)
# Step 7: Handle missing values
daily_sales = self._handle_missing_values(daily_sales)
logger.debug("Step 9 completed: Handle missing values",
inventory_product_id=inventory_product_id,
missing_values_handled=True)
logger.debug("Starting Step 10: Prepare for Prophet format",
inventory_product_id=inventory_product_id)
# Step 8: Prepare for Prophet (rename columns and validate)
prophet_data = self._prepare_prophet_format(daily_sales)
logger.debug("Step 10 completed: Prepare for Prophet format",
inventory_product_id=inventory_product_id,
prophet_records=len(prophet_data))
logger.debug("Starting Step 11: Store processing metadata",
inventory_product_id=inventory_product_id)
# Step 9: Store processing metadata if we have a tenant
if tenant_id:
await self._store_processing_metadata(
repos, tenant_id, inventory_product_id, prophet_data, job_id, session
)
logger.debug("Step 11 completed: Store processing metadata",
inventory_product_id=inventory_product_id)
logger.info("Enhanced training data prepared successfully", logger.info("Enhanced training data prepared successfully",
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
data_points=len(prophet_data)) data_points=len(prophet_data))
return prophet_data return prophet_data
except Exception as e: except Exception as e:
logger.error("Error preparing enhanced training data", logger.error("Error preparing enhanced training data",
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
error=str(e)) error=str(e),
exc_info=True)
raise raise
async def _store_processing_metadata(self, async def _store_processing_metadata(self,
@@ -210,7 +294,8 @@ class EnhancedBakeryDataProcessor:
tenant_id: str,
inventory_product_id: str,
processed_data: pd.DataFrame,
job_id: str = None,
session=None):
"""Store data processing metadata using repository"""
try:
# Create processing metadata
@@ -230,9 +315,12 @@ class EnhancedBakeryDataProcessor:
await repos['training_log'].update_log_progress(
job_id, 25, f"data_prepared_{inventory_product_id}", "running"
)
# If a session was provided, don't commit here; the caller manages its lifecycle
if session is not None:
pass
logger.debug("Data preparation metadata stored",
inventory_product_id=inventory_product_id)
except Exception as e:
@@ -358,69 +446,160 @@ class EnhancedBakeryDataProcessor:
async def _process_sales_data(self, sales_data: pd.DataFrame, inventory_product_id: str) -> pd.DataFrame:
"""Process and clean sales data with enhanced validation"""
logger.debug("Starting sales data processing",
inventory_product_id=inventory_product_id,
total_records=len(sales_data),
columns=list(sales_data.columns))
sales_clean = sales_data.copy()
logger.debug("Checking for date column existence",
inventory_product_id=inventory_product_id)
# Ensure date column exists and is datetime
if 'date' not in sales_clean.columns:
logger.error("Sales data must have a 'date' column",
inventory_product_id=inventory_product_id,
available_columns=list(sales_data.columns))
raise ValueError("Sales data must have a 'date' column")
logger.debug("Converting date column to datetime",
inventory_product_id=inventory_product_id)
sales_clean['date'] = pd.to_datetime(sales_clean['date'])
logger.debug("Date conversion completed",
inventory_product_id=inventory_product_id)
# Handle different quantity column names
quantity_columns = ['quantity', 'quantity_sold', 'sales', 'units_sold']
logger.debug("Looking for quantity column",
inventory_product_id=inventory_product_id,
quantity_columns=quantity_columns)
quantity_col = None
for col in quantity_columns:
if col in sales_clean.columns:
quantity_col = col
logger.debug("Found quantity column",
inventory_product_id=inventory_product_id,
quantity_column=col)
break
if quantity_col is None:
logger.error("Sales data must have one of the expected quantity columns",
inventory_product_id=inventory_product_id,
expected_columns=quantity_columns,
available_columns=list(sales_clean.columns))
raise ValueError(f"Sales data must have one of these columns: {quantity_columns}") raise ValueError(f"Sales data must have one of these columns: {quantity_columns}")
# Standardize to 'quantity' # Standardize to 'quantity'
if quantity_col != 'quantity': if quantity_col != 'quantity':
logger.debug("Mapping quantity column",
inventory_product_id=inventory_product_id,
from_column=quantity_col,
to_column='quantity')
sales_clean['quantity'] = sales_clean[quantity_col] sales_clean['quantity'] = sales_clean[quantity_col]
logger.info("Mapped quantity column", logger.info("Mapped quantity column",
from_column=quantity_col, from_column=quantity_col,
to_column='quantity') to_column='quantity')
logger.debug("Converting quantity to numeric",
inventory_product_id=inventory_product_id)
sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce') sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
logger.debug("Quantity conversion completed",
inventory_product_id=inventory_product_id,
non_numeric_count=sales_clean['quantity'].isna().sum())
# Remove rows with invalid quantities
logger.debug("Removing rows with invalid quantities",
inventory_product_id=inventory_product_id)
sales_clean = sales_clean.dropna(subset=['quantity'])
logger.debug("NaN rows removed",
inventory_product_id=inventory_product_id,
remaining_records=len(sales_clean))
sales_clean = sales_clean[sales_clean['quantity'] >= 0]  # No negative sales
logger.debug("Negative sales removed",
inventory_product_id=inventory_product_id,
remaining_records=len(sales_clean))
# Filter for the specific product if inventory_product_id column exists
logger.debug("Checking for inventory_product_id column",
inventory_product_id=inventory_product_id,
has_inventory_column='inventory_product_id' in sales_clean.columns)
if 'inventory_product_id' in sales_clean.columns:
logger.debug("Filtering for specific product",
inventory_product_id=inventory_product_id,
products_in_data=sales_clean['inventory_product_id'].unique()[:5].tolist())  # Show first 5
original_count = len(sales_clean)
sales_clean = sales_clean[sales_clean['inventory_product_id'] == inventory_product_id]
logger.debug("Product filtering completed",
inventory_product_id=inventory_product_id,
original_count=original_count,
filtered_count=len(sales_clean))
# Remove duplicate dates (keep the one with highest quantity)
logger.debug("Removing duplicate dates",
inventory_product_id=inventory_product_id,
before_dedupe=len(sales_clean))
sales_clean = sales_clean.sort_values(['date', 'quantity'], ascending=[True, False])
sales_clean = sales_clean.drop_duplicates(subset=['date'], keep='first')
logger.debug("Duplicate dates removed",
inventory_product_id=inventory_product_id,
after_dedupe=len(sales_clean))
logger.debug("Sales data processing completed",
inventory_product_id=inventory_product_id,
final_records=len(sales_clean))
return sales_clean
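The cleaning sequence above (pick the first recognised quantity column, coerce it to numeric, drop invalid rows, keep one row per date with the highest quantity) reduces to a handful of pandas idioms. A standalone sketch, not the project's actual helper:

    import pandas as pd

    def clean_sales(sales: pd.DataFrame) -> pd.DataFrame:
        # Sketch of the standardise-then-deduplicate flow shown in the diff above.
        df = sales.copy()
        df['date'] = pd.to_datetime(df['date'])

        # Use the first quantity-like column that exists.
        candidates = ['quantity', 'quantity_sold', 'sales', 'units_sold']
        quantity_col = next((c for c in candidates if c in df.columns), None)
        if quantity_col is None:
            raise ValueError(f"expected one of {candidates}")
        df['quantity'] = pd.to_numeric(df[quantity_col], errors='coerce')

        # Drop unusable rows: non-numeric or negative quantities.
        df = df.dropna(subset=['quantity'])
        df = df[df['quantity'] >= 0]

        # One row per date, keeping the largest quantity for duplicate dates.
        df = df.sort_values(['date', 'quantity'], ascending=[True, False])
        return df.drop_duplicates(subset=['date'], keep='first')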
async def _aggregate_daily_sales(self, sales_data: pd.DataFrame) -> pd.DataFrame:
"""Aggregate sales to daily level with improved date handling"""
logger.debug("Starting daily sales aggregation",
input_records=len(sales_data),
columns=list(sales_data.columns))
if sales_data.empty:
logger.debug("Sales data is empty, returning empty DataFrame")
return pd.DataFrame(columns=['date', 'quantity'])
logger.debug("Starting groupby aggregation",
unique_dates=sales_data['date'].nunique(),
date_range=(sales_data['date'].min(), sales_data['date'].max()))
# Group by date and sum quantities
daily_sales = sales_data.groupby('date').agg({
'quantity': 'sum'
}).reset_index()
logger.debug("Groupby aggregation completed",
aggregated_records=len(daily_sales))
# Ensure we have data for all dates in the range (fill gaps with 0)
logger.debug("Creating full date range",
start_date=daily_sales['date'].min(),
end_date=daily_sales['date'].max())
date_range = pd.date_range(
start=daily_sales['date'].min(),
end=daily_sales['date'].max(),
freq='D'
)
logger.debug("Date range created",
total_dates=len(date_range))
full_date_df = pd.DataFrame({'date': date_range})
logger.debug("Starting merge to fill missing dates",
full_date_records=len(full_date_df),
aggregated_records=len(daily_sales))
daily_sales = full_date_df.merge(daily_sales, on='date', how='left')
logger.debug("Missing date filling merge completed",
final_records=len(daily_sales))
daily_sales['quantity'] = daily_sales['quantity'].fillna(0)  # Fill missing days with 0 sales
logger.debug("NaN filling completed",
remaining_nan_count=daily_sales['quantity'].isna().sum(),
zero_filled_count=(daily_sales['quantity'] == 0).sum())
logger.debug("Daily sales aggregation completed",
final_records=len(daily_sales),
final_columns=len(daily_sales.columns))
return daily_sales
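The aggregation above uses a common pandas idiom: sum per day, then merge against a complete daily calendar so that days with no sales become explicit zeros (Prophet expects an unbroken daily series). In isolation:

    import pandas as pd

    def to_daily_series(sales: pd.DataFrame) -> pd.DataFrame:
        # Sum quantities per day and fill calendar gaps with 0; illustrative sketch.
        daily = sales.groupby('date', as_index=False)['quantity'].sum()
        calendar = pd.DataFrame({
            'date': pd.date_range(daily['date'].min(), daily['date'].max(), freq='D')
        })
        daily = calendar.merge(daily, on='date', how='left')
        daily['quantity'] = daily['quantity'].fillna(0)
        return daily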
@@ -466,6 +645,10 @@ class EnhancedBakeryDataProcessor:
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with enhanced Madrid-specific handling"""
logger.debug("Starting weather features merge",
daily_sales_records=len(daily_sales),
weather_data_records=len(weather_data) if not weather_data.empty else 0,
weather_columns=list(weather_data.columns) if not weather_data.empty else [])
# Define weather_defaults OUTSIDE try block to fix scope error
weather_defaults = {
@@ -477,27 +660,38 @@ class EnhancedBakeryDataProcessor:
}
if weather_data.empty:
logger.debug("Weather data is empty, adding default columns")
# Add default weather columns
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
logger.debug("Default weather columns added",
features_added=list(weather_defaults.keys()))
return daily_sales
try:
weather_clean = weather_data.copy()
logger.debug("Weather data copied",
records=len(weather_clean),
columns=list(weather_clean.columns))
# Standardize date column
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
logger.debug("Renaming ds column to date")
weather_clean = weather_clean.rename(columns={'ds': 'date'})
# CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
logger.debug("Converting weather data date column to datetime")
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
logger.debug("Converting daily sales date column to datetime")
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
if weather_clean['date'].dt.tz is not None:
logger.debug("Removing timezone from weather data")
weather_clean['date'] = weather_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
logger.debug("Removing timezone from daily sales data")
daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)
# Map weather columns to standard names
@@ -510,14 +704,24 @@ class EnhancedBakeryDataProcessor:
}
weather_features = ['date']
logger.debug("Mapping weather columns",
mapping_attempts=list(weather_mapping.keys()))
for standard_name, possible_names in weather_mapping.items():
for possible_name in possible_names:
if possible_name in weather_clean.columns:
logger.debug("Processing weather column",
standard_name=standard_name,
possible_name=possible_name,
records=len(weather_clean))
# Extract numeric values using robust helper function
try:
# Check if column contains dict-like objects
logger.debug("Checking for dict objects in weather column")
has_dicts = weather_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
logger.debug("Dict object check completed",
has_dicts=has_dicts)
if has_dicts:
logger.warning(f"Weather column {possible_name} contains dict objects, extracting numeric values")
@@ -525,9 +729,14 @@ class EnhancedBakeryDataProcessor:
weather_clean[standard_name] = weather_clean[possible_name].apply(
self._extract_numeric_from_dict
)
logger.debug("Dict extraction completed for weather column",
extracted_column=standard_name,
extracted_count=weather_clean[standard_name].notna().sum())
else:
# Direct numeric conversion for simple values
logger.debug("Performing direct numeric conversion")
weather_clean[standard_name] = pd.to_numeric(weather_clean[possible_name], errors='coerce')
logger.debug("Direct numeric conversion completed")
except Exception as e:
logger.warning(f"Error converting weather column {possible_name}: {e}")
# Fallback: try to extract from each value
@@ -535,28 +744,55 @@ class EnhancedBakeryDataProcessor:
self._extract_numeric_from_dict
)
weather_features.append(standard_name)
logger.debug("Added weather feature to list",
feature=standard_name)
break
# Keep only the features we found
logger.debug("Selecting weather features",
selected_features=weather_features)
weather_clean = weather_clean[weather_features].copy()
# Merge with sales data
logger.debug("Starting merge operation",
daily_sales_rows=len(daily_sales),
weather_rows=len(weather_clean),
date_range_sales=(daily_sales['date'].min(), daily_sales['date'].max()) if len(daily_sales) > 0 else None,
date_range_weather=(weather_clean['date'].min(), weather_clean['date'].max()) if len(weather_clean) > 0 else None)
merged = daily_sales.merge(weather_clean, on='date', how='left')
logger.debug("Merge completed",
merged_rows=len(merged),
merge_type='left')
# Fill missing weather values with Madrid-appropriate defaults
logger.debug("Filling missing weather values",
features_to_fill=list(weather_defaults.keys()))
for feature, default_value in weather_defaults.items():
if feature in merged.columns:
logger.debug("Processing feature for NaN fill",
feature=feature,
nan_count=merged[feature].isna().sum())
# Ensure the column is numeric before filling
merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
merged[feature] = merged[feature].fillna(default_value)
logger.debug("NaN fill completed for feature",
feature=feature,
final_nan_count=merged[feature].isna().sum())
logger.debug("Weather features merge completed",
final_rows=len(merged),
final_columns=len(merged.columns))
return merged
except Exception as e:
logger.warning("Error merging weather data", error=str(e), exc_info=True)
# Add default weather columns if merge fails
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
logger.debug("Default weather columns added after merge failure",
features_added=list(weather_defaults.keys()))
return daily_sales
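The timezone handling above exists because pandas will not merge a tz-aware datetime column with a tz-naive one (recent versions raise a merge-key dtype error). Normalising both sides to naive UTC before merging, as the diff does, looks like this on its own (a sketch, not the project's helper):

    import pandas as pd

    def to_naive_utc(frame: pd.DataFrame, column: str = 'date') -> pd.DataFrame:
        # Convert a datetime column to tz-naive UTC so two frames can be merged safely.
        frame = frame.copy()
        frame[column] = pd.to_datetime(frame[column])
        if frame[column].dt.tz is not None:
            frame[column] = frame[column].dt.tz_convert('UTC').dt.tz_localize(None)
        return frame

    # merged = to_naive_utc(daily_sales).merge(to_naive_utc(weather_data), on='date', how='left')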
@@ -564,28 +800,43 @@ class EnhancedBakeryDataProcessor:
daily_sales: pd.DataFrame,
traffic_data: pd.DataFrame) -> pd.DataFrame:
"""Merge traffic features with enhanced Madrid-specific handling"""
logger.debug("Starting traffic features merge",
daily_sales_records=len(daily_sales),
traffic_data_records=len(traffic_data) if not traffic_data.empty else 0,
traffic_columns=list(traffic_data.columns) if not traffic_data.empty else [])
if traffic_data.empty:
logger.debug("Traffic data is empty, adding default column")
# Add default traffic column
daily_sales['traffic_volume'] = 100.0  # Neutral traffic level
logger.debug("Default traffic column added",
default_value=100.0)
return daily_sales
try:
traffic_clean = traffic_data.copy()
logger.debug("Traffic data copied",
records=len(traffic_clean),
columns=list(traffic_clean.columns))
# Standardize date column
if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
logger.debug("Renaming ds column to date")
traffic_clean = traffic_clean.rename(columns={'ds': 'date'})
# CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
logger.debug("Converting traffic data date column to datetime")
traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
logger.debug("Converting daily sales date column to datetime")
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
if traffic_clean['date'].dt.tz is not None:
logger.debug("Removing timezone from traffic data")
traffic_clean['date'] = traffic_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
logger.debug("Removing timezone from daily sales data")
daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)
# Map traffic columns to standard names
@@ -597,14 +848,24 @@ class EnhancedBakeryDataProcessor:
}
traffic_features = ['date']
logger.debug("Mapping traffic columns",
mapping_attempts=list(traffic_mapping.keys()))
for standard_name, possible_names in traffic_mapping.items():
for possible_name in possible_names:
if possible_name in traffic_clean.columns:
logger.debug("Processing traffic column",
standard_name=standard_name,
possible_name=possible_name,
records=len(traffic_clean))
# Extract numeric values using robust helper function
try:
# Check if column contains dict-like objects
logger.debug("Checking for dict objects in traffic column")
has_dicts = traffic_clean[possible_name].apply(lambda x: isinstance(x, dict)).any()
logger.debug("Dict object check completed",
has_dicts=has_dicts)
if has_dicts:
logger.warning(f"Traffic column {possible_name} contains dict objects, extracting numeric values")
@@ -612,9 +873,14 @@ class EnhancedBakeryDataProcessor:
traffic_clean[standard_name] = traffic_clean[possible_name].apply(
self._extract_numeric_from_dict
)
logger.debug("Dict extraction completed for traffic column",
extracted_column=standard_name,
extracted_count=traffic_clean[standard_name].notna().sum())
else:
# Direct numeric conversion for simple values
logger.debug("Performing direct numeric conversion")
traffic_clean[standard_name] = pd.to_numeric(traffic_clean[possible_name], errors='coerce')
logger.debug("Direct numeric conversion completed")
except Exception as e:
logger.warning(f"Error converting traffic column {possible_name}: {e}")
# Fallback: try to extract from each value
@@ -622,14 +888,28 @@ class EnhancedBakeryDataProcessor:
self._extract_numeric_from_dict
)
traffic_features.append(standard_name)
logger.debug("Added traffic feature to list",
feature=standard_name)
break
# Keep only the features we found
logger.debug("Selecting traffic features",
selected_features=traffic_features)
traffic_clean = traffic_clean[traffic_features].copy()
# Merge with sales data
logger.debug("Starting traffic merge operation",
daily_sales_rows=len(daily_sales),
traffic_rows=len(traffic_clean),
date_range_sales=(daily_sales['date'].min(), daily_sales['date'].max()) if len(daily_sales) > 0 else None,
date_range_traffic=(traffic_clean['date'].min(), traffic_clean['date'].max()) if len(traffic_clean) > 0 else None)
merged = daily_sales.merge(traffic_clean, on='date', how='left')
logger.debug("Traffic merge completed",
merged_rows=len(merged),
merge_type='left')
# Fill missing traffic values with reasonable defaults
traffic_defaults = {
'traffic_volume': 100.0,
@@ -638,18 +918,31 @@ class EnhancedBakeryDataProcessor:
'average_speed': 30.0  # km/h typical for Madrid
}
logger.debug("Filling missing traffic values",
features_to_fill=list(traffic_defaults.keys()))
for feature, default_value in traffic_defaults.items():
if feature in merged.columns:
logger.debug("Processing traffic feature for NaN fill",
feature=feature,
nan_count=merged[feature].isna().sum())
# Ensure the column is numeric before filling
merged[feature] = pd.to_numeric(merged[feature], errors='coerce')
merged[feature] = merged[feature].fillna(default_value)
logger.debug("NaN fill completed for traffic feature",
feature=feature,
final_nan_count=merged[feature].isna().sum())
logger.debug("Traffic features merge completed",
final_rows=len(merged),
final_columns=len(merged.columns))
return merged
except Exception as e:
logger.warning("Error merging traffic data", error=str(e), exc_info=True)
# Add default traffic column if merge fails
daily_sales['traffic_volume'] = 100.0
logger.debug("Default traffic column added after merge failure",
default_value=100.0)
return daily_sales
def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -774,12 +1067,26 @@ class EnhancedBakeryDataProcessor:
""" """
df = df.copy() df = df.copy()
logger.info("Adding advanced features (lagged, rolling, cyclical, trends)") logger.info("Adding advanced features (lagged, rolling, cyclical, trends)",
input_rows=len(df),
input_columns=len(df.columns))
# Log column dtypes to identify potential issues
logger.debug("Input dataframe dtypes",
dtypes={col: str(dtype) for col, dtype in df.dtypes.items()},
date_column_exists='date' in df.columns)
# Reset feature engineer to clear previous features
logger.debug("Initializing AdvancedFeatureEngineer")
self.feature_engineer = AdvancedFeatureEngineer()
# Create all advanced features at once
logger.debug("Starting creation of advanced features",
include_lags=True,
include_rolling=True,
include_interactions=True,
include_cyclical=True)
df = self.feature_engineer.create_all_features(
df,
date_column='date',
@@ -789,8 +1096,16 @@ class EnhancedBakeryDataProcessor:
include_cyclical=True
)
logger.debug("Advanced features creation completed",
output_rows=len(df),
output_columns=len(df.columns))
# Fill NA values from lagged and rolling features
logger.debug("Starting NA value filling",
na_counts={col: df[col].isna().sum() for col in df.columns if df[col].isna().any()})
df = self.feature_engineer.fill_na_values(df, strategy='forward_backward')
logger.debug("NA value filling completed",
remaining_na_counts={col: df[col].isna().sum() for col in df.columns if df[col].isna().any()})
# Store created feature columns for later reference
created_features = self.feature_engineer.get_feature_columns()
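The fill_na_values(..., strategy='forward_backward') call deals with the NaNs that lag and rolling-window features necessarily create at the start of the series. Assuming that strategy means a forward fill followed by a backward fill (the method itself is not part of this diff), the equivalent pandas operation is roughly:

    import pandas as pd

    def fill_feature_gaps(df: pd.DataFrame) -> pd.DataFrame:
        # Forward-fill then backward-fill numeric columns; sketch of the assumed strategy.
        df = df.copy()
        numeric_cols = df.select_dtypes(include='number').columns
        df[numeric_cols] = df[numeric_cols].ffill().bfill()
        return df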


@@ -77,6 +77,7 @@ class EnhancedBakeryMLTrainer:
tenant_id: Tenant identifier
training_dataset: Prepared training dataset with aligned dates
job_id: Training job identifier
session: Database session to use (if None, creates one)
Returns:
Dictionary with training results for each product
@@ -89,68 +90,20 @@ class EnhancedBakeryMLTrainer:
tenant_id=tenant_id)
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session if session is not None else None
# Use the provided session or create a new one if needed
if should_create_session:
async with self.database_manager.get_session() as db_session:
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, db_session
)
else:
# Use the provided session (no context manager needed since caller manages it)
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, session
)
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
@@ -177,7 +130,7 @@ class EnhancedBakeryMLTrainer:
# Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
)
# Categorize all products for category-specific forecasting
@@ -302,11 +255,219 @@ class EnhancedBakeryMLTrainer:
raise
async def _execute_training_pipeline(self, tenant_id: str, training_dataset: TrainingDataSet,
job_id: str, session) -> Dict[str, Any]:
"""
Execute the training pipeline with the given session.
This is extracted to avoid code duplication when handling provided vs. created sessions.
"""
# Get repositories with the session
repos = await self._get_repositories(session)
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products from the sales data
products = sales_df['inventory_product_id'].unique().tolist()
# Debug: Log sales data details to understand why only one product is found
total_sales_records = len(sales_df)
sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
logger.info("Enhanced training pipeline - Sales data analysis",
total_sales_records=total_sales_records,
products_count=len(products),
products=products,
sales_by_product=sales_by_product)
if len(products) == 1:
logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
tenant_id=tenant_id,
single_product_id=products[0],
total_sales_records=total_sales_records)
elif len(products) == 0:
raise ValueError("No products found in sales data")
else:
logger.info("Multiple products detected for training",
products_count=len(products))
# Event 1: Training Started (0%) - update with actual product count AND time estimates
# Calculate accurate time estimates now that we know the actual product count
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
# Try to get historical average for more accurate estimates
try:
historical_avg = await get_historical_average_estimate(
session,
tenant_id
)
avg_time_per_product = historical_avg if historical_avg else 60.0
logger.info("Using historical average for time estimation",
avg_time_per_product=avg_time_per_product,
has_historical_data=historical_avg is not None)
except Exception as e:
logger.warning("Could not get historical average, using default",
error=str(e))
avg_time_per_product = 60.0
estimated_duration_minutes = calculate_initial_estimate(
total_products=len(products),
avg_training_time_per_product=avg_time_per_product
)
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Note: Initial event was already published by API endpoint with estimated product count,
# this updates with real count and recalculated time estimates based on actual data
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(products),
estimated_duration_minutes=estimated_duration_minutes,
estimated_completion_time=estimated_completion_time.isoformat()
)
# Create initial training log entry
await repos['training_log'].update_log_progress(
job_id, 5, "data_processing", "running"
)
# ✅ FIX: Flush the session to ensure the update is committed before proceeding
# This prevents deadlocks when training methods need to acquire locks
await session.flush()
logger.debug("Flushed session after initial progress update")
# Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
)
# Categorize all products for category-specific forecasting
logger.info("Categorizing products for optimized forecasting")
product_categories = await self._categorize_all_products(
sales_df, processed_data
)
logger.info("Product categorization complete",
total_products=len(product_categories),
categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
for cat in set(product_categories.values())})
# Event 2: Data Analysis (20%)
# Recalculate time remaining based on elapsed time
start_time = await repos['training_log'].get_start_time(job_id)
elapsed_seconds = 0
if start_time:
elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
# Estimate remaining time: we've done ~20% of work (data analysis)
# Remaining 80% includes training all products
products_to_train = len(processed_data)
estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
# Recalculate estimated completion time
estimated_completion_time_data_analysis = calculate_estimated_completion_time(
estimated_remaining_seconds / 60
)
await publish_data_analysis(
job_id,
tenant_id,
f"Data analysis completed for {len(processed_data)} products",
estimated_time_remaining_seconds=estimated_remaining_seconds,
estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
)
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")
# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(processed_data)
)
# Train all models in parallel (without DB writes to avoid session conflicts)
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, session
)
# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
tenant_id, job_id, training_results, repos
)
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
)
# Calculate successful and failed trainings
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])
# Event 4: Training Completed (100%)
await publish_training_completed(
job_id,
tenant_id,
successful_trainings,
failed_trainings,
total_duration
)
# Create comprehensive result with repository data
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
"total_products": len(products),
"training_results": training_results,
"enhanced_summary": summary,
"models_trained": summary.get('models_created', {}),
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"repository_metadata": {
"total_records_created": summary.get('total_db_records', 0),
"performance_metrics_stored": summary.get('performance_metrics_created', 0),
"artifacts_created": summary.get('artifacts_created', 0)
},
"completed_at": datetime.now().isoformat()
}
logger.info("Enhanced ML training pipeline completed successfully",
job_id=job_id,
models_created=len([r for r in training_results.values() if r.get('status') == 'success']))
return result
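The calculate_initial_estimate and calculate_estimated_completion_time helpers used above come from app.utils.time_estimation and are not shown in this diff; the arithmetic they stand in for is essentially products x average-seconds-per-product projected onto the clock. A hypothetical stand-in:

    from datetime import datetime, timedelta, timezone

    def estimate_completion(total_products: int, avg_seconds_per_product: float = 60.0):
        # Rough sketch of the published estimate, not the real utility functions.
        remaining_seconds = total_products * avg_seconds_per_product
        estimated_minutes = remaining_seconds / 60
        completion_time = datetime.now(timezone.utc) + timedelta(seconds=remaining_seconds)
        return estimated_minutes, completion_time

    # e.g. 12 products at ~60 s each -> ~12 minutes, finishing ~12 minutes from now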
async def train_single_product_model(self,
tenant_id: str,
inventory_product_id: str,
training_data: pd.DataFrame,
job_id: str = None,
session=None) -> Dict[str, Any]:
"""
Train a model for a single product using repository pattern.
@@ -315,6 +476,7 @@ class EnhancedBakeryMLTrainer:
inventory_product_id: Specific inventory product to train
training_data: Prepared training DataFrame for the product
job_id: Training job identifier (optional)
session: Database session to use (if None, creates one)
Returns:
Dictionary with model training results
@@ -329,9 +491,14 @@ class EnhancedBakeryMLTrainer:
data_points=len(training_data))
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session if session is not None else None
if should_create_session:
# Only create a session if one wasn't provided
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
@@ -367,7 +534,8 @@ class EnhancedBakeryMLTrainer:
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=db_session  # Pass the session to prevent nested sessions
)
logger.info("Single product training completed",
@@ -391,7 +559,7 @@ class EnhancedBakeryMLTrainer:
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
@@ -403,6 +571,8 @@ class EnhancedBakeryMLTrainer:
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}" "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
} }
return result_dict
except Exception as e: except Exception as e:
logger.error("Single product model training failed", logger.error("Single product model training failed",
job_id=job_id, job_id=job_id,
@@ -452,7 +622,8 @@ class EnhancedBakeryMLTrainer:
traffic_df: pd.DataFrame,
products: List[str],
tenant_id: str,
job_id: str,
session=None) -> Dict[str, pd.DataFrame]:
"""Process data for all products using enhanced processor with repository tracking"""
processed_data = {}
@@ -470,13 +641,15 @@ class EnhancedBakeryMLTrainer:
continue
# Use enhanced data processor with repository tracking
# Pass the session to prevent nested session issues
processed_product_data = await self.enhanced_data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
tenant_id=tenant_id,
job_id=job_id,
session=session  # Pass the session to avoid creating new ones
)
processed_data[inventory_product_id] = processed_product_data
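One detail in the trainer hunks above is worth noting: when the pipeline runs inside a session owned by the service layer, it calls await session.flush() after the initial progress update instead of commit(). With an SQLAlchemy AsyncSession, flush() sends the pending statements to the database inside the caller's open transaction, so later statements on the same connection see the update without a second session contending for the row, while the commit decision stays with the caller. A hedged sketch of the distinction:

    # Sketch only; assumes an SQLAlchemy AsyncSession owned by the caller and a
    # persistent `log` object already loaded in that session.
    async def record_progress(session, log, pct: int) -> None:
        log.progress = pct
        # flush(): emit the UPDATE now, within the caller's open transaction.
        # The change is visible to subsequent statements on this connection,
        # but nothing is made permanent yet.
        await session.flush()
        # Deliberately no commit(): committing here would end the caller's
        # transaction mid-pipeline and release locks it may still rely on.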


@@ -244,7 +244,8 @@ class EnhancedTrainingService:
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id,
session=session  # Pass the main session to avoid nested sessions
)
await self.training_log_repo.update_log_progress(
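Taken together, the three files thread a single session from the service entry point down through the trainer and into the data processor, so all progress updates for one job happen in one transaction instead of several competing ones. A condensed, hypothetical view of the call chain (names shortened, not the actual signatures):

    # Illustrative only, assuming the shared-session pattern described above.
    async def run_training_job(service, tenant_id, dataset, job_id):
        async with service.database_manager.get_session() as session:
            # The service layer owns the session...
            results = await service.trainer.train_tenant_models(
                tenant_id=tenant_id,
                training_dataset=dataset,
                job_id=job_id,
                session=session,  # ...and hands it to the trainer,
            )
            # ...which passes it to the data processor for per-product progress
            # updates, so no nested session is opened anywhere for this job.
            await session.commit()
            return results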