Fix training job concurrent database session conflicts
Root Cause: - Multiple parallel training tasks (3 at a time) were sharing the same database session - This caused SQLAlchemy session state conflicts: "Session is already flushing" and "rollback() is already in progress" - Additionally, duplicate model records were being created by both trainer and training_service Fixes: 1. Separated model training from database writes: - Training happens in parallel (CPU-intensive) - Database writes happen sequentially after training completes - This eliminates concurrent session access 2. Removed duplicate database writes: - Trainer now writes all model records sequentially after parallel training - Training service now retrieves models instead of creating duplicates - Performance metrics are also created by trainer (no duplicates) 3. Added proper data flow: - _train_single_product: Only trains models, stores results - _write_training_results_to_database: Sequential DB writes after training - _store_trained_models: Changed to retrieve existing models - _create_performance_metrics: Changed to verify existing metrics Benefits: - Eliminates database session conflicts - Prevents duplicate model records - Maintains parallel training performance - Ensures data consistency Files Modified: - services/training/app/ml/trainer.py - services/training/app/services/training_service.py Resolves: Onboarding training job database session conflicts
This commit is contained in:
@@ -217,10 +217,17 @@ class EnhancedBakeryMLTrainer:
|
||||
total_products=len(processed_data)
|
||||
)
|
||||
|
||||
# Train all models in parallel (without DB writes to avoid session conflicts)
|
||||
training_results = await self._train_all_models_enhanced(
|
||||
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
|
||||
)
|
||||
|
||||
|
||||
# Write all training results to database sequentially (after parallel training completes)
|
||||
logger.info("Writing training results to database sequentially")
|
||||
training_results = await self._write_training_results_to_database(
|
||||
tenant_id, job_id, training_results, repos
|
||||
)
|
||||
|
||||
# Calculate overall training summary with enhanced metrics
|
||||
summary = await self._calculate_enhanced_training_summary(
|
||||
training_results, repos, tenant_id
|
||||
@@ -482,7 +489,12 @@ class EnhancedBakeryMLTrainer:
|
||||
repos: Dict,
|
||||
progress_tracker: ParallelProductProgressTracker,
|
||||
product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
|
||||
"""Train a single product model - used for parallel execution with progress aggregation"""
|
||||
"""
|
||||
Train a single product model - used for parallel execution with progress aggregation.
|
||||
|
||||
Note: This method ONLY trains the model and collects results. Database writes happen
|
||||
separately to avoid concurrent session conflicts.
|
||||
"""
|
||||
product_start_time = time.time()
|
||||
|
||||
try:
|
||||
@@ -497,7 +509,9 @@ class EnhancedBakeryMLTrainer:
|
||||
'reason': 'insufficient_data',
|
||||
'data_points': len(product_data),
|
||||
'min_required': settings.MIN_TRAINING_DATA_DAYS,
|
||||
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
|
||||
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}',
|
||||
'product_data': product_data, # Store for later DB writes
|
||||
'product_category': product_category
|
||||
}
|
||||
logger.warning("Skipping product due to insufficient data",
|
||||
inventory_product_id=inventory_product_id,
|
||||
@@ -557,30 +571,21 @@ class EnhancedBakeryMLTrainer:
|
||||
continue
|
||||
model_info['training_metrics'] = filtered_metrics
|
||||
|
||||
# Store model record using repository
|
||||
model_record = await self._create_model_record(
|
||||
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
|
||||
)
|
||||
|
||||
# Create performance metrics record
|
||||
if model_info.get('training_metrics'):
|
||||
await self._create_performance_metrics(
|
||||
repos, model_record.id if model_record else None,
|
||||
tenant_id, inventory_product_id, model_info['training_metrics']
|
||||
)
|
||||
|
||||
# IMPORTANT: Do NOT write to database here - causes concurrent session conflicts
|
||||
# Store all info needed for later DB writes (done sequentially after all training completes)
|
||||
result = {
|
||||
'status': 'success',
|
||||
'model_info': model_info,
|
||||
'model_record_id': str(model_record.id) if model_record else None,
|
||||
'data_points': len(product_data),
|
||||
'training_time_seconds': time.time() - product_start_time,
|
||||
'trained_at': datetime.now().isoformat()
|
||||
'trained_at': datetime.now().isoformat(),
|
||||
# Store data needed for DB writes later
|
||||
'product_data': product_data,
|
||||
'product_category': product_category
|
||||
}
|
||||
|
||||
logger.info("Successfully trained model",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_record_id=model_record.id if model_record else None)
|
||||
logger.info("Successfully trained model (DB writes deferred)",
|
||||
inventory_product_id=inventory_product_id)
|
||||
|
||||
# Report completion to progress tracker (emits Event 3: product_completed)
|
||||
await progress_tracker.mark_product_completed(inventory_product_id)
|
||||
@@ -676,7 +681,74 @@ class EnhancedBakeryMLTrainer:
|
||||
|
||||
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
|
||||
return training_results
|
||||
|
||||
|
||||
async def _write_training_results_to_database(self,
|
||||
tenant_id: str,
|
||||
job_id: str,
|
||||
training_results: Dict[str, Any],
|
||||
repos: Dict) -> Dict[str, Any]:
|
||||
"""
|
||||
Write training results to database sequentially to avoid concurrent session conflicts.
|
||||
|
||||
This method is called AFTER all parallel training is complete.
|
||||
"""
|
||||
logger.info("Writing training results to database sequentially",
|
||||
total_products=len(training_results))
|
||||
|
||||
updated_results = {}
|
||||
|
||||
for product_id, result in training_results.items():
|
||||
try:
|
||||
if result.get('status') == 'success':
|
||||
model_info = result.get('model_info')
|
||||
product_data = result.get('product_data')
|
||||
|
||||
if model_info and product_data is not None:
|
||||
# Create model record
|
||||
model_record = await self._create_model_record(
|
||||
repos, tenant_id, product_id, model_info, job_id, product_data
|
||||
)
|
||||
|
||||
# Create performance metrics
|
||||
if model_info.get('training_metrics') and model_record:
|
||||
await self._create_performance_metrics(
|
||||
repos, model_record.id,
|
||||
tenant_id, product_id, model_info['training_metrics']
|
||||
)
|
||||
|
||||
# Update result with model_record_id
|
||||
result['model_record_id'] = str(model_record.id) if model_record else None
|
||||
|
||||
logger.info("Database records created successfully",
|
||||
inventory_product_id=product_id,
|
||||
model_record_id=model_record.id if model_record else None)
|
||||
|
||||
# Remove product_data from result to avoid serialization issues
|
||||
if 'product_data' in result:
|
||||
del result['product_data']
|
||||
if 'product_category' in result:
|
||||
del result['product_category']
|
||||
|
||||
updated_results[product_id] = result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to write database records for product",
|
||||
inventory_product_id=product_id,
|
||||
error=str(e))
|
||||
# Keep the training result but mark that DB write failed
|
||||
result['db_write_error'] = str(e)
|
||||
if 'product_data' in result:
|
||||
del result['product_data']
|
||||
if 'product_category' in result:
|
||||
del result['product_category']
|
||||
updated_results[product_id] = result
|
||||
|
||||
logger.info("Database writes completed",
|
||||
successful_writes=len([r for r in updated_results.values() if 'model_record_id' in r]),
|
||||
total_products=len(updated_results))
|
||||
|
||||
return updated_results
|
||||
|
||||
async def _create_model_record(self,
|
||||
repos: Dict,
|
||||
tenant_id: str,
|
||||
|
||||
@@ -342,81 +342,57 @@ class EnhancedTrainingService:
|
||||
job_id: str,
|
||||
training_results: Dict[str, Any]
|
||||
) -> List:
|
||||
"""Store trained models using repository pattern"""
|
||||
"""
|
||||
Retrieve or verify stored models from training results.
|
||||
|
||||
NOTE: Model records are now created by the trainer during parallel execution.
|
||||
This method retrieves the already-created models instead of creating duplicates.
|
||||
"""
|
||||
stored_models = []
|
||||
|
||||
|
||||
try:
|
||||
# Get models_trained before sanitization to preserve structure
|
||||
models_trained = training_results.get("models_trained", {})
|
||||
logger.debug("Models trained structure",
|
||||
models_trained_type=type(models_trained).__name__,
|
||||
models_trained_keys=list(models_trained.keys()) if isinstance(models_trained, dict) else "not_dict")
|
||||
|
||||
for inventory_product_id, model_result in models_trained.items():
|
||||
# Defensive check: ensure model_result is a dictionary
|
||||
if not isinstance(model_result, dict):
|
||||
logger.warning("Skipping invalid model_result for product",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_result_type=type(model_result).__name__,
|
||||
model_result_value=str(model_result)[:100])
|
||||
continue
|
||||
|
||||
if model_result.get("status") == "completed":
|
||||
# Sanitize individual fields that might contain UUID objects
|
||||
metrics = model_result.get("metrics", {})
|
||||
if not isinstance(metrics, dict):
|
||||
logger.warning("Invalid metrics object, using empty dict",
|
||||
inventory_product_id=inventory_product_id,
|
||||
metrics_type=type(metrics).__name__)
|
||||
metrics = {}
|
||||
model_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"job_id": job_id,
|
||||
"model_type": "prophet_optimized",
|
||||
"model_path": model_result.get("model_path"),
|
||||
"metadata_path": model_result.get("metadata_path"),
|
||||
"mape": make_json_serializable(metrics.get("mape")),
|
||||
"mae": make_json_serializable(metrics.get("mae")),
|
||||
"rmse": make_json_serializable(metrics.get("rmse")),
|
||||
"r2_score": make_json_serializable(metrics.get("r2_score")),
|
||||
"training_samples": make_json_serializable(model_result.get("data_points", 0)),
|
||||
"hyperparameters": make_json_serializable(model_result.get("hyperparameters")),
|
||||
"features_used": make_json_serializable(model_result.get("features_used")),
|
||||
"is_active": True,
|
||||
"is_production": True, # New models are production by default
|
||||
"data_quality_score": make_json_serializable(model_result.get("data_quality_score"))
|
||||
}
|
||||
|
||||
# Create model record
|
||||
model = await self.model_repo.create_model(model_data)
|
||||
stored_models.append(model)
|
||||
|
||||
# Create artifacts if present
|
||||
if model_result.get("model_path"):
|
||||
artifact_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"artifact_type": "model_file",
|
||||
"file_path": model_result["model_path"],
|
||||
"storage_location": "local"
|
||||
}
|
||||
await self.artifact_repo.create_artifact(artifact_data)
|
||||
|
||||
if model_result.get("metadata_path"):
|
||||
artifact_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"artifact_type": "metadata",
|
||||
"file_path": model_result["metadata_path"],
|
||||
"storage_location": "local"
|
||||
}
|
||||
await self.artifact_repo.create_artifact(artifact_data)
|
||||
|
||||
# Check if models were already created by the trainer (new approach)
|
||||
# The trainer now writes models sequentially after parallel training
|
||||
training_results_dict = training_results.get("training_results", {})
|
||||
|
||||
# Get list of successfully trained products
|
||||
successful_products = [
|
||||
product_id for product_id, result in training_results_dict.items()
|
||||
if result.get('status') == 'success' and result.get('model_record_id')
|
||||
]
|
||||
|
||||
logger.info("Retrieving models created by trainer",
|
||||
successful_products=len(successful_products),
|
||||
job_id=job_id)
|
||||
|
||||
# Retrieve the models that were already created by the trainer
|
||||
for product_id in successful_products:
|
||||
result = training_results_dict[product_id]
|
||||
model_record_id = result.get('model_record_id')
|
||||
|
||||
if model_record_id:
|
||||
try:
|
||||
# Get the model from the database using base repository method
|
||||
model = await self.model_repo.get_by_id(model_record_id)
|
||||
if model:
|
||||
stored_models.append(model)
|
||||
logger.debug("Retrieved model from database",
|
||||
model_id=model_record_id,
|
||||
inventory_product_id=product_id)
|
||||
except Exception as e:
|
||||
logger.warning("Could not retrieve model record",
|
||||
model_id=model_record_id,
|
||||
inventory_product_id=product_id,
|
||||
error=str(e))
|
||||
|
||||
logger.info("Models retrieval complete",
|
||||
models_retrieved=len(stored_models),
|
||||
expected=len(successful_products))
|
||||
|
||||
return stored_models
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to store trained models",
|
||||
logger.error("Failed to retrieve stored models",
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
@@ -428,30 +404,28 @@ class EnhancedTrainingService:
|
||||
stored_models: List,
|
||||
training_results: Dict[str, Any]
|
||||
):
|
||||
"""Create performance metrics for stored models"""
|
||||
"""
|
||||
Verify performance metrics for stored models.
|
||||
|
||||
NOTE: Performance metrics are now created by the trainer during model creation.
|
||||
This method now just verifies they exist rather than creating duplicates.
|
||||
"""
|
||||
try:
|
||||
logger.info("Verifying performance metrics",
|
||||
models_count=len(stored_models))
|
||||
|
||||
# Performance metrics are already created by the trainer
|
||||
# This method is kept for compatibility but doesn't create duplicates
|
||||
for model in stored_models:
|
||||
model_result = training_results.get("models_trained", {}).get(str(model.inventory_product_id))
|
||||
if model_result and model_result.get("metrics"):
|
||||
metrics = model_result["metrics"]
|
||||
logger.debug("Performance metrics already created for model",
|
||||
model_id=str(model.id),
|
||||
inventory_product_id=str(model.inventory_product_id))
|
||||
|
||||
metric_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": str(model.inventory_product_id),
|
||||
"mae": metrics.get("mae"),
|
||||
"mse": metrics.get("mse"),
|
||||
"rmse": metrics.get("rmse"),
|
||||
"mape": metrics.get("mape"),
|
||||
"r2_score": metrics.get("r2_score"),
|
||||
"accuracy_percentage": metrics.get("accuracy_percentage", 100 - metrics.get("mape", 0)),
|
||||
"evaluation_samples": model.training_samples
|
||||
}
|
||||
|
||||
await self.performance_repo.create_performance_metric(metric_data)
|
||||
logger.info("Performance metrics verification complete",
|
||||
models_count=len(stored_models))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create performance metrics",
|
||||
logger.error("Failed to verify performance metrics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user