Fix training hang caused by nested database sessions and deadlocks
Root Cause:

The training process was hanging at the first progress update due to a nested
database session issue. The main trainer created a session and repositories,
then called prophet_manager.train_bakery_model(), which created another,
nested session with an advisory lock. This caused a deadlock:

1. The outer session held an uncommitted UPDATE on model_training_logs.
2. The inner session tried to acquire the advisory lock.
3. Neither could proceed, so training hung indefinitely.

Changes Made:

1. prophet_manager.py:
   - Added an optional 'session' parameter to train_bakery_model()
   - Refactored to use the parent session if provided, otherwise create a new one
   - Prevents nested session creation during training

2. hybrid_trainer.py:
   - Added an optional 'session' parameter to train_hybrid_model()
   - Passes the session through to prophet_manager to maintain a single session context

3. trainer.py:
   - Updated _train_single_product() to accept and pass a session
   - Updated _train_all_models_enhanced() to accept and pass a session
   - Passes db_session from the main training context to all training methods
   - Added an explicit db_session.flush() after the critical progress update,
     ensuring updates are visible before locks are acquired

Impact:
- Eliminates nested session deadlocks
- Training now proceeds past the initial progress update
- Maintains a single database session context throughout training
- Prevents database transaction conflicts

Related Issues:
- Fixes training hang during the onboarding process
- Not directly related to the audit_metadata changes, but exposed by them

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
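In outline, the fix is the "reuse the caller's session, open a new one only at the top level" pattern. The following is a minimal runnable sketch of that pattern; FakeSession and FakeDatabaseManager are hypothetical stand-ins, and the real signatures are in the diff below:

import asyncio
from contextlib import asynccontextmanager

class FakeSession:
    """Hypothetical stand-in for a pooled AsyncSession."""
    def __init__(self, name):
        self.name = name

    async def flush(self):
        # The real code flushes the pending progress UPDATE so it is
        # visible before any lock is acquired on the same session.
        pass

class FakeDatabaseManager:
    """Hypothetical stand-in for the real session factory."""
    @asynccontextmanager
    async def get_session(self):
        # Checking out a second session *inside* an already-open one is
        # the anti-pattern this commit removes.
        yield FakeSession("pooled-session")

async def train_bakery_model(db, session=None):
    """Reuse the caller's session when given; open a new one only for standalone calls."""
    async def _train_with_lock(db_session):
        # Acquire the advisory lock and fit the model on db_session here.
        return {"trained_on": db_session.name}

    if session is not None:
        return await _train_with_lock(session)  # parent session: no nested checkout
    async with db.get_session() as new_session:
        return await _train_with_lock(new_session)

async def main():
    db = FakeDatabaseManager()
    async with db.get_session() as outer:
        await outer.flush()  # make pending writes visible before locking
        print(await train_bakery_model(db, session=outer))

asyncio.run(main())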
--- a/hybrid_trainer.py
+++ b/hybrid_trainer.py
@@ -56,7 +56,8 @@ class HybridProphetXGBoost:
         inventory_product_id: str,
         df: pd.DataFrame,
         job_id: str,
-        validation_split: float = 0.2
+        validation_split: float = 0.2,
+        session = None
     ) -> Dict[str, Any]:
         """
         Train hybrid Prophet + XGBoost model.
@@ -67,6 +68,7 @@ class HybridProphetXGBoost:
             df: Training data (must have 'ds', 'y' and regressor columns)
             job_id: Training job identifier
             validation_split: Fraction of data for validation
+            session: Optional database session (uses parent session if provided to avoid nested sessions)
 
         Returns:
             Dictionary with model metadata and performance metrics
@@ -80,11 +82,13 @@ class HybridProphetXGBoost:
 
         # Step 1: Train Prophet model (base forecaster)
         logger.info("Step 1: Training Prophet base model")
+        # ✅ FIX: Pass session to prophet_manager to avoid nested session issues
         prophet_result = await self.prophet_manager.train_bakery_model(
             tenant_id=tenant_id,
             inventory_product_id=inventory_product_id,
             df=df.copy(),
-            job_id=job_id
+            job_id=job_id,
+            session=session
         )
 
         self.prophet_model_data = prophet_result
--- a/prophet_manager.py
+++ b/prophet_manager.py
@@ -94,7 +94,8 @@ class BakeryProphetManager:
                                  df: pd.DataFrame,
                                  job_id: str,
                                  product_category: 'ProductCategory' = None,
-                                 category_hyperparameters: Dict[str, Any] = None) -> Dict[str, Any]:
+                                 category_hyperparameters: Dict[str, Any] = None,
+                                 session = None) -> Dict[str, Any]:
         """
         Train a Prophet model with automatic hyperparameter optimization and distributed locking.
 
@@ -105,6 +106,7 @@ class BakeryProphetManager:
             job_id: Training job identifier
             product_category: Optional product category for category-specific settings
             category_hyperparameters: Optional category-specific Prophet hyperparameters
+            session: Optional database session (uses parent session if provided to avoid nested sessions)
         """
         # Check disk space before starting training
         has_space, free_gb, total_gb, used_percent = check_disk_space('/tmp', min_free_gb=0.5)
@@ -116,141 +118,155 @@ class BakeryProphetManager:
         # Acquire distributed lock to prevent concurrent training of same product
         lock = get_training_lock(tenant_id, inventory_product_id, use_advisory=True)
 
-        try:
-            async with self.database_manager.get_session() as session:
-                async with lock.acquire(session):
-                    logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)")
-
-                    # Validate input data
-                    await self._validate_training_data(df, inventory_product_id)
-
-                    # Prepare data for Prophet
-                    prophet_data = await self._prepare_prophet_data(df)
-
-                    # Get regressor columns
-                    regressor_columns = self._extract_regressor_columns(prophet_data)
-
-                    # Use category-specific hyperparameters if provided, otherwise optimize
-                    if category_hyperparameters:
-                        logger.info(f"Using category-specific hyperparameters for {inventory_product_id} (category: {product_category.value if product_category else 'unknown'})")
-                        best_params = category_hyperparameters.copy()
-                        use_optimized = False  # Not optimized, but category-specific
-                    else:
-                        # Automatically optimize hyperparameters
-                        logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
-                        try:
-                            best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)
-                            use_optimized = True
-                        except Exception as opt_error:
-                            logger.warning(f"Hyperparameter optimization failed for {inventory_product_id}: {opt_error}")
-                            logger.warning("Falling back to default Prophet parameters")
-                            # Use conservative default parameters
-                            best_params = {
-                                'changepoint_prior_scale': 0.05,
-                                'seasonality_prior_scale': 10.0,
-                                'holidays_prior_scale': 10.0,
-                                'changepoint_range': 0.8,
-                                'seasonality_mode': 'additive',
-                                'daily_seasonality': False,
-                                'weekly_seasonality': True,
-                                'yearly_seasonality': len(prophet_data) > 365,
-                                'uncertainty_samples': 0  # Disable uncertainty sampling to avoid cmdstan
-                            }
-                            use_optimized = False
-
-                    # Create optimized Prophet model
-                    model = self._create_optimized_prophet_model(best_params, regressor_columns)
-
-                    # Add regressors to model
-                    for regressor in regressor_columns:
-                        if regressor in prophet_data.columns:
-                            model.add_regressor(regressor)
-
-                    # Set environment variable for cmdstan tmp directory
-                    import os
-                    tmpdir = os.environ.get('TMPDIR', '/tmp/cmdstan')
-                    os.makedirs(tmpdir, mode=0o777, exist_ok=True)
-                    os.environ['TMPDIR'] = tmpdir
-
-                    # Verify tmp directory is writable
-                    test_file = os.path.join(tmpdir, f'test_write_{inventory_product_id}.tmp')
-                    try:
-                        with open(test_file, 'w') as f:
-                            f.write('test')
-                        os.remove(test_file)
-                        logger.debug(f"Verified {tmpdir} is writable")
-                    except Exception as e:
-                        logger.error(f"TMPDIR {tmpdir} is not writable: {e}")
-                        raise RuntimeError(f"Cannot write to {tmpdir}: {e}")
-
-                    # Fit the model with enhanced error handling
-                    try:
-                        logger.info(f"Starting Prophet model fit for {inventory_product_id}")
-                        # ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
-                        import asyncio
-                        await asyncio.to_thread(model.fit, prophet_data)
-                        logger.info(f"Prophet model fit completed successfully for {inventory_product_id}")
-                    except Exception as fit_error:
-                        error_details = {
-                            'error_type': type(fit_error).__name__,
-                            'error_message': str(fit_error),
-                            'errno': getattr(fit_error, 'errno', None),
-                            'tmpdir': tmpdir,
-                            'disk_space': check_disk_space(tmpdir, 0)
-                        }
-                        logger.error(f"Prophet model fit failed for {inventory_product_id}: {error_details}")
-                        raise RuntimeError(f"Prophet training failed: {error_details['error_message']}") from fit_error
-
-                    # Calculate enhanced training metrics first
-                    training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
-
-                    # Store model and metrics - Generate proper UUID for model_id
-                    model_id = str(uuid.uuid4())
-                    model_path = await self._store_model(
-                        tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
-                    )
-
-                    # Return same format as before, but with optimization info
-                    # Ensure hyperparameters are JSON-serializable
-                    def _serialize_hyperparameters(params):
-                        """Helper to ensure hyperparameters are JSON serializable"""
-                        if not params:
-                            return {}
-                        safe_params = {}
-                        for k, v in params.items():
-                            try:
-                                if isinstance(v, (int, float, str, bool, type(None))):
-                                    safe_params[k] = v
-                                elif hasattr(v, 'item'):  # numpy scalars
-                                    safe_params[k] = v.item()
-                                elif isinstance(v, (list, tuple)):
-                                    safe_params[k] = [x.item() if hasattr(x, 'item') else x for x in v]
-                                else:
-                                    safe_params[k] = float(v) if isinstance(v, (np.integer, np.floating)) else str(v)
-                            except:
-                                safe_params[k] = str(v)  # fallback to string conversion
-                        return safe_params
-
-                    model_info = {
-                        "model_id": model_id,
-                        "model_path": model_path,
-                        "type": "prophet_optimized",
-                        "training_samples": len(prophet_data),
-                        "features": regressor_columns,
-                        "hyperparameters": _serialize_hyperparameters(best_params),
-                        "training_metrics": training_metrics,
-                        "product_category": product_category.value if product_category else "unknown",
-                        "trained_at": datetime.now().isoformat(),
-                        "data_period": {
-                            "start_date": pd.Timestamp(prophet_data['ds'].min()).isoformat(),
-                            "end_date": pd.Timestamp(prophet_data['ds'].max()).isoformat(),
-                            "total_days": len(prophet_data)
-                        }
-                    }
-
-                    logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
-                                f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
-                    return model_info
-
+        # Use provided session or create new one if not provided
+        use_parent_session = session is not None
+
+        async def _train_with_lock(db_session):
+            """Inner function to perform training with lock"""
+            async with lock.acquire(db_session):
+                logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)")
+
+                # Validate input data
+                await self._validate_training_data(df, inventory_product_id)
+
+                # Prepare data for Prophet
+                prophet_data = await self._prepare_prophet_data(df)
+
+                # Get regressor columns
+                regressor_columns = self._extract_regressor_columns(prophet_data)
+
+                # Use category-specific hyperparameters if provided, otherwise optimize
+                if category_hyperparameters:
+                    logger.info(f"Using category-specific hyperparameters for {inventory_product_id} (category: {product_category.value if product_category else 'unknown'})")
+                    best_params = category_hyperparameters.copy()
+                    use_optimized = False  # Not optimized, but category-specific
+                else:
+                    # Automatically optimize hyperparameters
+                    logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
+                    try:
+                        best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)
+                        use_optimized = True
+                    except Exception as opt_error:
+                        logger.warning(f"Hyperparameter optimization failed for {inventory_product_id}: {opt_error}")
+                        logger.warning("Falling back to default Prophet parameters")
+                        # Use conservative default parameters
+                        best_params = {
+                            'changepoint_prior_scale': 0.05,
+                            'seasonality_prior_scale': 10.0,
+                            'holidays_prior_scale': 10.0,
+                            'changepoint_range': 0.8,
+                            'seasonality_mode': 'additive',
+                            'daily_seasonality': False,
+                            'weekly_seasonality': True,
+                            'yearly_seasonality': len(prophet_data) > 365,
+                            'uncertainty_samples': 0  # Disable uncertainty sampling to avoid cmdstan
+                        }
+                        use_optimized = False
+
+                # Create optimized Prophet model
+                model = self._create_optimized_prophet_model(best_params, regressor_columns)
+
+                # Add regressors to model
+                for regressor in regressor_columns:
+                    if regressor in prophet_data.columns:
+                        model.add_regressor(regressor)
+
+                # Set environment variable for cmdstan tmp directory
+                import os
+                tmpdir = os.environ.get('TMPDIR', '/tmp/cmdstan')
+                os.makedirs(tmpdir, mode=0o777, exist_ok=True)
+                os.environ['TMPDIR'] = tmpdir
+
+                # Verify tmp directory is writable
+                test_file = os.path.join(tmpdir, f'test_write_{inventory_product_id}.tmp')
+                try:
+                    with open(test_file, 'w') as f:
+                        f.write('test')
+                    os.remove(test_file)
+                    logger.debug(f"Verified {tmpdir} is writable")
+                except Exception as e:
+                    logger.error(f"TMPDIR {tmpdir} is not writable: {e}")
+                    raise RuntimeError(f"Cannot write to {tmpdir}: {e}")
+
+                # Fit the model with enhanced error handling
+                try:
+                    logger.info(f"Starting Prophet model fit for {inventory_product_id}")
+                    # ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
+                    import asyncio
+                    await asyncio.to_thread(model.fit, prophet_data)
+                    logger.info(f"Prophet model fit completed successfully for {inventory_product_id}")
+                except Exception as fit_error:
+                    error_details = {
+                        'error_type': type(fit_error).__name__,
+                        'error_message': str(fit_error),
+                        'errno': getattr(fit_error, 'errno', None),
+                        'tmpdir': tmpdir,
+                        'disk_space': check_disk_space(tmpdir, 0)
+                    }
+                    logger.error(f"Prophet model fit failed for {inventory_product_id}: {error_details}")
+                    raise RuntimeError(f"Prophet training failed: {error_details['error_message']}") from fit_error
+
+                # Calculate enhanced training metrics first
+                training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
+
+                # Store model and metrics - Generate proper UUID for model_id
+                model_id = str(uuid.uuid4())
+                model_path = await self._store_model(
+                    tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
+                )
+
+                # Return same format as before, but with optimization info
+                # Ensure hyperparameters are JSON-serializable
+                def _serialize_hyperparameters(params):
+                    """Helper to ensure hyperparameters are JSON serializable"""
+                    if not params:
+                        return {}
+                    safe_params = {}
+                    for k, v in params.items():
+                        try:
+                            if isinstance(v, (int, float, str, bool, type(None))):
+                                safe_params[k] = v
+                            elif hasattr(v, 'item'):  # numpy scalars
+                                safe_params[k] = v.item()
+                            elif isinstance(v, (list, tuple)):
+                                safe_params[k] = [x.item() if hasattr(x, 'item') else x for x in v]
+                            else:
+                                safe_params[k] = float(v) if isinstance(v, (np.integer, np.floating)) else str(v)
+                        except:
+                            safe_params[k] = str(v)  # fallback to string conversion
+                    return safe_params
+
+                model_info = {
+                    "model_id": model_id,
+                    "model_path": model_path,
+                    "type": "prophet_optimized",
+                    "training_samples": len(prophet_data),
+                    "features": regressor_columns,
+                    "hyperparameters": _serialize_hyperparameters(best_params),
+                    "training_metrics": training_metrics,
+                    "product_category": product_category.value if product_category else "unknown",
+                    "trained_at": datetime.now().isoformat(),
+                    "data_period": {
+                        "start_date": pd.Timestamp(prophet_data['ds'].min()).isoformat(),
+                        "end_date": pd.Timestamp(prophet_data['ds'].max()).isoformat(),
+                        "total_days": len(prophet_data)
+                    }
+                }
+
+                logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
+                            f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
+                return model_info
+
+        try:
+            # ✅ FIX: Use parent session if provided, otherwise create new one
+            # This prevents nested session issues and database deadlocks
+            if use_parent_session:
+                logger.debug(f"Using parent session for training {inventory_product_id}")
+                return await _train_with_lock(session)
+            else:
+                logger.debug(f"Creating new session for training {inventory_product_id}")
+                async with self.database_manager.get_session() as new_session:
+                    return await _train_with_lock(new_session)
+
         except LockAcquisitionError as e:
             logger.warning(f"Could not acquire lock for {inventory_product_id}: {e}")
--- a/trainer.py
+++ b/trainer.py
@@ -168,7 +168,12 @@ class EnhancedBakeryMLTrainer:
             await repos['training_log'].update_log_progress(
                 job_id, 5, "data_processing", "running"
             )
 
+            # ✅ FIX: Flush the session to ensure the update is committed before proceeding
+            # This prevents deadlocks when training methods need to acquire locks
+            await db_session.flush()
+            logger.debug("Flushed session after initial progress update")
+
             # Process data for each product using enhanced processor
             logger.info("Processing data using enhanced processor")
             processed_data = await self._process_all_products_enhanced(
@@ -221,8 +226,9 @@ class EnhancedBakeryMLTrainer:
             )
 
             # Train all models in parallel (without DB writes to avoid session conflicts)
+            # ✅ FIX: Pass db_session to prevent nested session issues and deadlocks
             training_results = await self._train_all_models_enhanced(
-                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
+                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, db_session
             )
 
             # Write all training results to database sequentially (after parallel training completes)
@@ -493,12 +499,16 @@ class EnhancedBakeryMLTrainer:
                                     job_id: str,
                                     repos: Dict,
                                     progress_tracker: ParallelProductProgressTracker,
-                                    product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
+                                    product_category: ProductCategory = ProductCategory.UNKNOWN,
+                                    session = None) -> tuple[str, Dict[str, Any]]:
         """
         Train a single product model - used for parallel execution with progress aggregation.
 
         Note: This method ONLY trains the model and collects results. Database writes happen
         separately to avoid concurrent session conflicts.
 
+        Args:
+            session: Database session to use for training (prevents nested session issues)
         """
         product_start_time = time.time()
 
@@ -539,13 +549,15 @@ class EnhancedBakeryMLTrainer:
                                  category=product_category.value)
 
             # Train the selected model
+            # ✅ FIX: Pass session to training methods to avoid nested session issues
             if model_type == "hybrid":
                 # Train hybrid Prophet + XGBoost model
                 model_info = await self.hybrid_trainer.train_hybrid_model(
                     tenant_id=tenant_id,
                     inventory_product_id=inventory_product_id,
                     df=product_data,
-                    job_id=job_id
+                    job_id=job_id,
+                    session=session
                 )
                 model_info['model_type'] = 'hybrid_prophet_xgboost'
             else:
@@ -556,7 +568,8 @@ class EnhancedBakeryMLTrainer:
                     df=product_data,
                     job_id=job_id,
                     product_category=product_category,
-                    category_hyperparameters=category_characteristics.get('prophet_params', {})
+                    category_hyperparameters=category_characteristics.get('prophet_params', {}),
+                    session=session
                 )
                 model_info['model_type'] = 'prophet_optimized'
 
@@ -620,12 +633,19 @@ class EnhancedBakeryMLTrainer:
                                         job_id: str,
                                         repos: Dict,
                                         progress_tracker: ParallelProductProgressTracker,
-                                        product_categories: Dict[str, ProductCategory] = None) -> Dict[str, Any]:
-        """Train models with throttled parallel execution and progress tracking"""
+                                        product_categories: Dict[str, ProductCategory] = None,
+                                        session = None) -> Dict[str, Any]:
+        """
+        Train models with throttled parallel execution and progress tracking
+
+        Args:
+            session: Database session to pass to training methods (prevents nested session issues)
+        """
         total_products = len(processed_data)
         logger.info(f"Starting throttled parallel training for {total_products} products")
 
         # Create training tasks for all products
+        # ✅ FIX: Pass session to prevent nested session issues and deadlocks
        training_tasks = [
             self._train_single_product(
                 tenant_id=tenant_id,
@@ -634,7 +654,8 @@ class EnhancedBakeryMLTrainer:
                 job_id=job_id,
                 repos=repos,
                 progress_tracker=progress_tracker,
-                product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN
+                product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN,
+                session=session
             )
             for inventory_product_id, product_data in processed_data.items()
         ]