Fix training hang by wrapping blocking ML operations in thread pool
Root Cause: The training process was stuck at 40% because blocking synchronous ML operations (model.fit(), model.predict(), study.optimize()) were freezing the asyncio event loop, preventing RabbitMQ heartbeats, WebSocket communication, and progress updates.

Changes:
1. prophet_manager.py:
   - Wrapped model.fit() at line 189 with asyncio.to_thread()
   - Wrapped study.optimize() at line 453 with asyncio.to_thread()
2. hybrid_trainer.py:
   - Made _train_xgboost() async and wrapped model.fit() with asyncio.to_thread()
   - Made _evaluate_hybrid_model() async and wrapped its predict() calls
   - Made predict() wrap its blocking predict() calls

Impact:
- Event loop no longer blocks during ML training
- RabbitMQ heartbeats continue during training
- WebSocket progress updates work correctly
- Training can now complete successfully

Fixes: Training hang at 40% during onboarding phase
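Every change below applies the same pattern: hand the blocking call to asyncio.to_thread() (Python 3.9+), which runs it in the default ThreadPoolExecutor so the event loop can keep servicing heartbeats and progress updates while awaiting the result. The minimal sketch below is not taken from the repository; blocking_fit() and heartbeat() are hypothetical stand-ins for the real training call and the RabbitMQ/WebSocket tasks, but it shows why the wrapper unblocks the loop:

import asyncio
import contextlib
import time


def blocking_fit(duration: float = 3.0) -> str:
    """Stand-in for a long synchronous call such as model.fit() or study.optimize()."""
    time.sleep(duration)  # simulates the blocking ML work
    return "model trained"


async def heartbeat() -> None:
    """Stand-in for RabbitMQ heartbeats / WebSocket progress updates."""
    while True:
        print("heartbeat", time.strftime("%H:%M:%S"))
        await asyncio.sleep(1)


async def train() -> None:
    hb = asyncio.create_task(heartbeat())
    # Calling blocking_fit() directly here would freeze the event loop and the
    # heartbeat task would go silent for the whole duration. Wrapped in
    # asyncio.to_thread(), the call runs in the default thread pool and the
    # loop stays responsive while we await the result.
    result = await asyncio.to_thread(blocking_fit, 3.0)
    print(result)
    hb.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await hb


if __name__ == "__main__":
    asyncio.run(train())

Note that asyncio.to_thread() does not make the training itself any faster; it only keeps the loop free, which is exactly what the heartbeats and progress updates need.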
hybrid_trainer.py
@@ -125,14 +125,14 @@ class HybridProphetXGBoost:

         # Step 7: Train XGBoost on residuals
         logger.info("Step 4: Training XGBoost on residuals")
-        self.xgb_model = self._train_xgboost(
+        self.xgb_model = await self._train_xgboost(
             X_train, train_residuals,
             X_val, val_residuals
         )

         # Step 8: Evaluate hybrid model
         logger.info("Step 5: Evaluating hybrid model performance")
-        metrics = self._evaluate_hybrid_model(
+        metrics = await self._evaluate_hybrid_model(
             train_df, val_df,
             train_prophet_pred, val_prophet_pred,
             prophet_result
@@ -238,7 +238,7 @@ class HybridProphetXGBoost:

         return forecast['yhat'].values

-    def _train_xgboost(
+    async def _train_xgboost(
         self,
         X_train: np.ndarray,
         y_train: np.ndarray,
@@ -275,8 +275,10 @@ class HybridProphetXGBoost:
         # Initialize model
         model = xgb.XGBRegressor(**params)

-        # Train with early stopping
-        model.fit(
+        # ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
+        import asyncio
+        await asyncio.to_thread(
+            model.fit,
             X_train, y_train,
             eval_set=[(X_val, y_val)],
             early_stopping_rounds=10,
@@ -290,7 +292,7 @@ class HybridProphetXGBoost:

         return model

-    def _evaluate_hybrid_model(
+    async def _evaluate_hybrid_model(
         self,
         train_df: pd.DataFrame,
         val_df: pd.DataFrame,
@@ -319,8 +321,10 @@ class HybridProphetXGBoost:
         X_train = train_df[self.feature_columns].values
         X_val = val_df[self.feature_columns].values

-        train_xgb_pred = self.xgb_model.predict(X_train)
-        val_xgb_pred = self.xgb_model.predict(X_val)
+        # ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
+        import asyncio
+        train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
+        val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)

         # Hybrid predictions = Prophet + XGBoost residual correction
         train_hybrid_pred = train_prophet_pred + train_xgb_pred
@@ -420,7 +424,9 @@ class HybridProphetXGBoost:
         """
         # Step 1: Get Prophet predictions
         prophet_model = model_data['prophet_model']
-        prophet_forecast = prophet_model.predict(future_df)
+        # ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
+        import asyncio
+        prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)

         # Step 2: Prepare features for XGBoost
         future_enhanced = self._prepare_xgboost_features(future_df)
@@ -429,7 +435,8 @@ class HybridProphetXGBoost:
         xgb_model = model_data['xgboost_model']
         feature_columns = model_data['feature_columns']
         X_future = future_enhanced[feature_columns].values
-        xgb_pred = xgb_model.predict(X_future)
+        # ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
+        xgb_pred = await asyncio.to_thread(xgb_model.predict, X_future)

         # Step 4: Combine predictions
         hybrid_pred = prophet_forecast['yhat'].values + xgb_pred

prophet_manager.py
@@ -186,7 +186,9 @@ class BakeryProphetManager:
         # Fit the model with enhanced error handling
         try:
             logger.info(f"Starting Prophet model fit for {inventory_product_id}")
-            model.fit(prophet_data)
+            # ✅ FIX: Run blocking model.fit() in thread pool to avoid blocking event loop
+            import asyncio
+            await asyncio.to_thread(model.fit, prophet_data)
             logger.info(f"Prophet model fit completed successfully for {inventory_product_id}")
         except Exception as fit_error:
             error_details = {
@@ -450,7 +452,15 @@ class BakeryProphetManager:
             direction='minimize',
             sampler=optuna.samplers.TPESampler(seed=product_seed)
         )
-        study.optimize(objective, n_trials=n_trials, timeout=const.OPTUNA_TIMEOUT_SECONDS, show_progress_bar=False)
+        # ✅ FIX: Run blocking study.optimize() in thread pool to avoid blocking event loop
+        import asyncio
+        await asyncio.to_thread(
+            study.optimize,
+            objective,
+            n_trials=n_trials,
+            timeout=const.OPTUNA_TIMEOUT_SECONDS,
+            show_progress_bar=False
+        )

         # Return best parameters
         best_params = study.best_params