Fix training hang caused by nested database sessions and deadlocks

Root Cause:
The training process was hanging at the first progress update due to a
nested database session issue. The main trainer created a session and
repositories, then called prophet_manager.train_bakery_model() which
created another nested session with an advisory lock. This caused a
deadlock where:
1. Outer session had uncommitted UPDATE on model_training_logs
2. Inner session tried to acquire advisory lock
3. Neither could proceed, causing training to hang indefinitely

Changes Made:
1. prophet_manager.py:
   - Added optional 'session' parameter to train_bakery_model()
   - Refactored to use the parent session if provided, otherwise create a new one
   - Prevents nested session creation during training

2. hybrid_trainer.py:
   - Added optional 'session' parameter to train_hybrid_model()
   - Passes session to prophet_manager to maintain single session context

3. trainer.py:
   - Updated _train_single_product() to accept and pass session
   - Updated _train_all_models_enhanced() to accept and pass session
   - Pass db_session from main training context to all training methods
   - Added explicit db_session.flush() after the critical progress update
   - flush() sends the pending UPDATE to the database within the current
     transaction (it does not commit), so the advisory lock subsequently
     acquired on the same session/connection is not blocked by it

Impact:
- Eliminates nested session deadlocks
- Training now proceeds past initial progress update
- Maintains single database session context throughout training
- Prevents database transaction conflicts

Related Issues:
- Fixes training hang during onboarding process
- Not directly related to audit_metadata changes but exposed by them

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Urtzi Alfaro
2025-11-05 16:13:32 +01:00
parent 7a315afa62
commit caff49761d
3 changed files with 174 additions and 133 deletions

View File

@@ -56,7 +56,8 @@ class HybridProphetXGBoost:
inventory_product_id: str, inventory_product_id: str,
df: pd.DataFrame, df: pd.DataFrame,
job_id: str, job_id: str,
validation_split: float = 0.2 validation_split: float = 0.2,
session = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Train hybrid Prophet + XGBoost model. Train hybrid Prophet + XGBoost model.
@@ -67,6 +68,7 @@ class HybridProphetXGBoost:
df: Training data (must have 'ds', 'y' and regressor columns) df: Training data (must have 'ds', 'y' and regressor columns)
job_id: Training job identifier job_id: Training job identifier
validation_split: Fraction of data for validation validation_split: Fraction of data for validation
session: Optional database session (uses parent session if provided to avoid nested sessions)
Returns: Returns:
Dictionary with model metadata and performance metrics Dictionary with model metadata and performance metrics
@@ -80,11 +82,13 @@ class HybridProphetXGBoost:
# Step 1: Train Prophet model (base forecaster) # Step 1: Train Prophet model (base forecaster)
logger.info("Step 1: Training Prophet base model") logger.info("Step 1: Training Prophet base model")
# ✅ FIX: Pass session to prophet_manager to avoid nested session issues
prophet_result = await self.prophet_manager.train_bakery_model( prophet_result = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id, tenant_id=tenant_id,
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
df=df.copy(), df=df.copy(),
job_id=job_id job_id=job_id,
session=session
) )
self.prophet_model_data = prophet_result self.prophet_model_data = prophet_result

View File

@@ -94,7 +94,8 @@ class BakeryProphetManager:
df: pd.DataFrame, df: pd.DataFrame,
job_id: str, job_id: str,
product_category: 'ProductCategory' = None, product_category: 'ProductCategory' = None,
category_hyperparameters: Dict[str, Any] = None) -> Dict[str, Any]: category_hyperparameters: Dict[str, Any] = None,
session = None) -> Dict[str, Any]:
""" """
Train a Prophet model with automatic hyperparameter optimization and distributed locking. Train a Prophet model with automatic hyperparameter optimization and distributed locking.
@@ -105,6 +106,7 @@ class BakeryProphetManager:
job_id: Training job identifier job_id: Training job identifier
product_category: Optional product category for category-specific settings product_category: Optional product category for category-specific settings
category_hyperparameters: Optional category-specific Prophet hyperparameters category_hyperparameters: Optional category-specific Prophet hyperparameters
session: Optional database session (uses parent session if provided to avoid nested sessions)
""" """
# Check disk space before starting training # Check disk space before starting training
has_space, free_gb, total_gb, used_percent = check_disk_space('/tmp', min_free_gb=0.5) has_space, free_gb, total_gb, used_percent = check_disk_space('/tmp', min_free_gb=0.5)
@@ -116,9 +118,12 @@ class BakeryProphetManager:
# Acquire distributed lock to prevent concurrent training of same product # Acquire distributed lock to prevent concurrent training of same product
lock = get_training_lock(tenant_id, inventory_product_id, use_advisory=True) lock = get_training_lock(tenant_id, inventory_product_id, use_advisory=True)
try: # Use provided session or create new one if not provided
async with self.database_manager.get_session() as session: use_parent_session = session is not None
async with lock.acquire(session):
async def _train_with_lock(db_session):
"""Inner function to perform training with lock"""
async with lock.acquire(db_session):
logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)") logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)")
# Validate input data # Validate input data
@@ -252,6 +257,17 @@ class BakeryProphetManager:
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%") f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
return model_info return model_info
try:
# ✅ FIX: Use parent session if provided, otherwise create new one
# This prevents nested session issues and database deadlocks
if use_parent_session:
logger.debug(f"Using parent session for training {inventory_product_id}")
return await _train_with_lock(session)
else:
logger.debug(f"Creating new session for training {inventory_product_id}")
async with self.database_manager.get_session() as new_session:
return await _train_with_lock(new_session)
except LockAcquisitionError as e: except LockAcquisitionError as e:
logger.warning(f"Could not acquire lock for {inventory_product_id}: {e}") logger.warning(f"Could not acquire lock for {inventory_product_id}: {e}")
raise RuntimeError(f"Training already in progress for product {inventory_product_id}") raise RuntimeError(f"Training already in progress for product {inventory_product_id}")

View File

@@ -169,6 +169,11 @@ class EnhancedBakeryMLTrainer:
job_id, 5, "data_processing", "running" job_id, 5, "data_processing", "running"
) )
# ✅ FIX: Flush the session so the pending update is sent to the DB before proceeding
# (flush emits the SQL within the open transaction — it does not commit; this
# prevents deadlocks when training methods later acquire advisory locks)
await db_session.flush()
logger.debug("Flushed session after initial progress update")
# Process data for each product using enhanced processor # Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor") logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced( processed_data = await self._process_all_products_enhanced(
@@ -221,8 +226,9 @@ class EnhancedBakeryMLTrainer:
) )
# Train all models in parallel (without DB writes to avoid session conflicts) # Train all models in parallel (without DB writes to avoid session conflicts)
# ✅ FIX: Pass db_session to prevent nested session issues and deadlocks
training_results = await self._train_all_models_enhanced( training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, db_session
) )
# Write all training results to database sequentially (after parallel training completes) # Write all training results to database sequentially (after parallel training completes)
@@ -493,12 +499,16 @@ class EnhancedBakeryMLTrainer:
job_id: str, job_id: str,
repos: Dict, repos: Dict,
progress_tracker: ParallelProductProgressTracker, progress_tracker: ParallelProductProgressTracker,
product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]: product_category: ProductCategory = ProductCategory.UNKNOWN,
session = None) -> tuple[str, Dict[str, Any]]:
""" """
Train a single product model - used for parallel execution with progress aggregation. Train a single product model - used for parallel execution with progress aggregation.
Note: This method ONLY trains the model and collects results. Database writes happen Note: This method ONLY trains the model and collects results. Database writes happen
separately to avoid concurrent session conflicts. separately to avoid concurrent session conflicts.
Args:
session: Database session to use for training (prevents nested session issues)
""" """
product_start_time = time.time() product_start_time = time.time()
@@ -539,13 +549,15 @@ class EnhancedBakeryMLTrainer:
category=product_category.value) category=product_category.value)
# Train the selected model # Train the selected model
# ✅ FIX: Pass session to training methods to avoid nested session issues
if model_type == "hybrid": if model_type == "hybrid":
# Train hybrid Prophet + XGBoost model # Train hybrid Prophet + XGBoost model
model_info = await self.hybrid_trainer.train_hybrid_model( model_info = await self.hybrid_trainer.train_hybrid_model(
tenant_id=tenant_id, tenant_id=tenant_id,
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
df=product_data, df=product_data,
job_id=job_id job_id=job_id,
session=session
) )
model_info['model_type'] = 'hybrid_prophet_xgboost' model_info['model_type'] = 'hybrid_prophet_xgboost'
else: else:
@@ -556,7 +568,8 @@ class EnhancedBakeryMLTrainer:
df=product_data, df=product_data,
job_id=job_id, job_id=job_id,
product_category=product_category, product_category=product_category,
category_hyperparameters=category_characteristics.get('prophet_params', {}) category_hyperparameters=category_characteristics.get('prophet_params', {}),
session=session
) )
model_info['model_type'] = 'prophet_optimized' model_info['model_type'] = 'prophet_optimized'
@@ -620,12 +633,19 @@ class EnhancedBakeryMLTrainer:
job_id: str, job_id: str,
repos: Dict, repos: Dict,
progress_tracker: ParallelProductProgressTracker, progress_tracker: ParallelProductProgressTracker,
product_categories: Dict[str, ProductCategory] = None) -> Dict[str, Any]: product_categories: Dict[str, ProductCategory] = None,
"""Train models with throttled parallel execution and progress tracking""" session = None) -> Dict[str, Any]:
"""
Train models with throttled parallel execution and progress tracking
Args:
session: Database session to pass to training methods (prevents nested session issues)
"""
total_products = len(processed_data) total_products = len(processed_data)
logger.info(f"Starting throttled parallel training for {total_products} products") logger.info(f"Starting throttled parallel training for {total_products} products")
# Create training tasks for all products # Create training tasks for all products
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_tasks = [ training_tasks = [
self._train_single_product( self._train_single_product(
tenant_id=tenant_id, tenant_id=tenant_id,
@@ -634,7 +654,8 @@ class EnhancedBakeryMLTrainer:
job_id=job_id, job_id=job_id,
repos=repos, repos=repos,
progress_tracker=progress_tracker, progress_tracker=progress_tracker,
product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN,
session=session
) )
for inventory_product_id, product_data in processed_data.items() for inventory_product_id, product_data in processed_data.items()
] ]