Fix training hang caused by nested database sessions and deadlocks
Root Cause: The training process was hanging at the first progress update due to a nested database session issue. The main trainer created a session and repositories, then called prophet_manager.train_bakery_model() which created another nested session with an advisory lock. This caused a deadlock where: 1. Outer session had uncommitted UPDATE on model_training_logs 2. Inner session tried to acquire advisory lock 3. Neither could proceed, causing training to hang indefinitely Changes Made: 1. prophet_manager.py: - Added optional 'session' parameter to train_bakery_model() - Refactored to use parent session if provided, otherwise create new one - Prevents nested session creation during training 2. hybrid_trainer.py: - Added optional 'session' parameter to train_hybrid_model() - Passes session to prophet_manager to maintain single session context 3. trainer.py: - Updated _train_single_product() to accept and pass session - Updated _train_all_models_enhanced() to accept and pass session - Pass db_session from main training context to all training methods - Added explicit db_session.flush() after critical progress update - This ensures updates are visible before acquiring locks Impact: - Eliminates nested session deadlocks - Training now proceeds past initial progress update - Maintains single database session context throughout training - Prevents database transaction conflicts Related Issues: - Fixes training hang during onboarding process - Not directly related to audit_metadata changes but exposed by them 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -168,7 +168,12 @@ class EnhancedBakeryMLTrainer:
|
||||
await repos['training_log'].update_log_progress(
|
||||
job_id, 5, "data_processing", "running"
|
||||
)
|
||||
|
||||
|
||||
# ✅ FIX: Flush the session to ensure the update is committed before proceeding
|
||||
# This prevents deadlocks when training methods need to acquire locks
|
||||
await db_session.flush()
|
||||
logger.debug("Flushed session after initial progress update")
|
||||
|
||||
# Process data for each product using enhanced processor
|
||||
logger.info("Processing data using enhanced processor")
|
||||
processed_data = await self._process_all_products_enhanced(
|
||||
@@ -221,8 +226,9 @@ class EnhancedBakeryMLTrainer:
|
||||
)
|
||||
|
||||
# Train all models in parallel (without DB writes to avoid session conflicts)
|
||||
# ✅ FIX: Pass db_session to prevent nested session issues and deadlocks
|
||||
training_results = await self._train_all_models_enhanced(
|
||||
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
|
||||
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, db_session
|
||||
)
|
||||
|
||||
# Write all training results to database sequentially (after parallel training completes)
|
||||
@@ -493,12 +499,16 @@ class EnhancedBakeryMLTrainer:
|
||||
job_id: str,
|
||||
repos: Dict,
|
||||
progress_tracker: ParallelProductProgressTracker,
|
||||
product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
|
||||
product_category: ProductCategory = ProductCategory.UNKNOWN,
|
||||
session = None) -> tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Train a single product model - used for parallel execution with progress aggregation.
|
||||
|
||||
Note: This method ONLY trains the model and collects results. Database writes happen
|
||||
separately to avoid concurrent session conflicts.
|
||||
|
||||
Args:
|
||||
session: Database session to use for training (prevents nested session issues)
|
||||
"""
|
||||
product_start_time = time.time()
|
||||
|
||||
@@ -539,13 +549,15 @@ class EnhancedBakeryMLTrainer:
|
||||
category=product_category.value)
|
||||
|
||||
# Train the selected model
|
||||
# ✅ FIX: Pass session to training methods to avoid nested session issues
|
||||
if model_type == "hybrid":
|
||||
# Train hybrid Prophet + XGBoost model
|
||||
model_info = await self.hybrid_trainer.train_hybrid_model(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
df=product_data,
|
||||
job_id=job_id
|
||||
job_id=job_id,
|
||||
session=session
|
||||
)
|
||||
model_info['model_type'] = 'hybrid_prophet_xgboost'
|
||||
else:
|
||||
@@ -556,7 +568,8 @@ class EnhancedBakeryMLTrainer:
|
||||
df=product_data,
|
||||
job_id=job_id,
|
||||
product_category=product_category,
|
||||
category_hyperparameters=category_characteristics.get('prophet_params', {})
|
||||
category_hyperparameters=category_characteristics.get('prophet_params', {}),
|
||||
session=session
|
||||
)
|
||||
model_info['model_type'] = 'prophet_optimized'
|
||||
|
||||
@@ -620,12 +633,19 @@ class EnhancedBakeryMLTrainer:
|
||||
job_id: str,
|
||||
repos: Dict,
|
||||
progress_tracker: ParallelProductProgressTracker,
|
||||
product_categories: Dict[str, ProductCategory] = None) -> Dict[str, Any]:
|
||||
"""Train models with throttled parallel execution and progress tracking"""
|
||||
product_categories: Dict[str, ProductCategory] = None,
|
||||
session = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Train models with throttled parallel execution and progress tracking
|
||||
|
||||
Args:
|
||||
session: Database session to pass to training methods (prevents nested session issues)
|
||||
"""
|
||||
total_products = len(processed_data)
|
||||
logger.info(f"Starting throttled parallel training for {total_products} products")
|
||||
|
||||
# Create training tasks for all products
|
||||
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
|
||||
training_tasks = [
|
||||
self._train_single_product(
|
||||
tenant_id=tenant_id,
|
||||
@@ -634,7 +654,8 @@ class EnhancedBakeryMLTrainer:
|
||||
job_id=job_id,
|
||||
repos=repos,
|
||||
progress_tracker=progress_tracker,
|
||||
product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN
|
||||
product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN,
|
||||
session=session
|
||||
)
|
||||
for inventory_product_id, product_data in processed_data.items()
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user