Fix training job concurrent database session conflicts
Root Cause:
- Multiple parallel training tasks (3 at a time) were sharing the same database session
- This caused SQLAlchemy session state conflicts: "Session is already flushing" and "rollback() is already in progress"
- Additionally, duplicate model records were being created by both trainer and training_service

Fixes:
1. Separated model training from database writes:
   - Training happens in parallel (CPU-intensive)
   - Database writes happen sequentially after training completes
   - This eliminates concurrent session access
2. Removed duplicate database writes:
   - Trainer now writes all model records sequentially after parallel training
   - Training service now retrieves models instead of creating duplicates
   - Performance metrics are also created by trainer (no duplicates)
3. Added proper data flow:
   - _train_single_product: only trains models and stores results
   - _write_training_results_to_database: sequential DB writes after training
   - _store_trained_models: changed to retrieve existing models
   - _create_performance_metrics: changed to verify existing metrics

Benefits:
- Eliminates database session conflicts
- Prevents duplicate model records
- Maintains parallel training performance
- Ensures data consistency

Files Modified:
- services/training/app/ml/trainer.py
- services/training/app/services/training_service.py

Resolves: Onboarding training job database session conflicts
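For illustration, a minimal sketch of the two-phase pattern described above: train concurrently, then persist sequentially through a single writer. This is not the project's code; train_one, persist, and train_all are hypothetical stand-ins for _train_all_models_enhanced and the repository writes.

    import asyncio
    from typing import Any, Dict, Tuple

    async def train_one(product_id: str, data: Any) -> Tuple[str, Dict[str, Any]]:
        # Phase 1: CPU-bound training; no database session is touched here.
        await asyncio.sleep(0)  # simulated training work
        return product_id, {'status': 'success', 'product_data': data}

    async def persist(product_id: str, result: Dict[str, Any]) -> None:
        # Phase 2: stands in for _create_model_record / _create_performance_metrics,
        # which share one session and therefore must not run concurrently.
        await asyncio.sleep(0)  # simulated database write

    async def train_all(products: Dict[str, Any]) -> Dict[str, Any]:
        # Train every product concurrently; results are only collected in memory.
        results = dict(await asyncio.gather(
            *(train_one(pid, data) for pid, data in products.items())
        ))
        # Persist one product at a time, so the session sees at most one flush.
        for pid, result in results.items():
            await persist(pid, result)
        return results

    asyncio.run(train_all({'croissant': [], 'baguette': []}))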
@@ -217,10 +217,17 @@ class EnhancedBakeryMLTrainer:
                 total_products=len(processed_data)
             )

+            # Train all models in parallel (without DB writes to avoid session conflicts)
             training_results = await self._train_all_models_enhanced(
                 tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
             )

+            # Write all training results to database sequentially (after parallel training completes)
+            logger.info("Writing training results to database sequentially")
+            training_results = await self._write_training_results_to_database(
+                tenant_id, job_id, training_results, repos
+            )
+
             # Calculate overall training summary with enhanced metrics
             summary = await self._calculate_enhanced_training_summary(
                 training_results, repos, tenant_id
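Background on the root cause: a SQLAlchemy Session (sync or async) is not safe for concurrent use, so three tasks flushing through one session interleave its internal state and produce exactly the "Session is already flushing" / "rollback() is already in progress" errors quoted in the commit message. The usual alternative fix is one session per task; a minimal sketch of that variant, assuming SQLAlchemy 2.x's async API (the ModelRecord table and the engine URL are placeholders, not this repo's schema, and table creation is omitted):

    from sqlalchemy import String
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class ModelRecord(Base):  # illustrative stand-in for the real model table
        __tablename__ = 'ml_models'
        id: Mapped[int] = mapped_column(primary_key=True)
        product_id: Mapped[str] = mapped_column(String(64))

    engine = create_async_engine('sqlite+aiosqlite:///:memory:')  # placeholder URL
    SessionFactory = async_sessionmaker(engine, expire_on_commit=False)

    async def write_result(product_id: str) -> None:
        # Each concurrent task opens its own session, so flushes never interleave.
        async with SessionFactory() as session:
            session.add(ModelRecord(product_id=product_id))
            await session.commit()

The commit takes the other route, keeping the shared repos session and serializing all writes; that avoids extra connections per task and also puts record creation on a single code path, which is what fixes the duplicate-record problem.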
@@ -482,7 +489,12 @@ class EnhancedBakeryMLTrainer:
                                     repos: Dict,
                                     progress_tracker: ParallelProductProgressTracker,
                                     product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
-        """Train a single product model - used for parallel execution with progress aggregation"""
+        """
+        Train a single product model - used for parallel execution with progress aggregation.
+
+        Note: This method ONLY trains the model and collects results. Database writes happen
+        separately to avoid concurrent session conflicts.
+        """
         product_start_time = time.time()

         try:
@@ -497,7 +509,9 @@ class EnhancedBakeryMLTrainer:
                     'reason': 'insufficient_data',
                     'data_points': len(product_data),
                     'min_required': settings.MIN_TRAINING_DATA_DAYS,
-                    'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
+                    'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}',
+                    'product_data': product_data,  # Store for later DB writes
+                    'product_category': product_category
                 }
                 logger.warning("Skipping product due to insufficient data",
                                inventory_product_id=inventory_product_id,
@@ -557,30 +571,21 @@ class EnhancedBakeryMLTrainer:
                        continue
                model_info['training_metrics'] = filtered_metrics

-            # Store model record using repository
-            model_record = await self._create_model_record(
-                repos, tenant_id, inventory_product_id, model_info, job_id, product_data
-            )
-
-            # Create performance metrics record
-            if model_info.get('training_metrics'):
-                await self._create_performance_metrics(
-                    repos, model_record.id if model_record else None,
-                    tenant_id, inventory_product_id, model_info['training_metrics']
-                )
-
+            # IMPORTANT: Do NOT write to database here - causes concurrent session conflicts
+            # Store all info needed for later DB writes (done sequentially after all training completes)
             result = {
                 'status': 'success',
                 'model_info': model_info,
-                'model_record_id': str(model_record.id) if model_record else None,
                 'data_points': len(product_data),
                 'training_time_seconds': time.time() - product_start_time,
-                'trained_at': datetime.now().isoformat()
+                'trained_at': datetime.now().isoformat(),
+                # Store data needed for DB writes later
+                'product_data': product_data,
+                'product_category': product_category
             }

-            logger.info("Successfully trained model",
-                        inventory_product_id=inventory_product_id,
-                        model_record_id=model_record.id if model_record else None)
+            logger.info("Successfully trained model (DB writes deferred)",
+                        inventory_product_id=inventory_product_id)

             # Report completion to progress tracker (emits Event 3: product_completed)
             await progress_tracker.mark_product_completed(inventory_product_id)
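The product_data payload added to result above is only carried temporarily; the next hunk strips it before results are returned, "to avoid serialization issues" in the diff's own words. A hypothetical illustration of that issue, assuming product_data is a pandas DataFrame and results are eventually JSON-encoded for job reporting:

    import json
    import pandas as pd

    result = {'status': 'success', 'product_data': pd.DataFrame({'qty': [1, 2]})}
    # json.dumps(result) would raise:
    # TypeError: Object of type DataFrame is not JSON serializable
    result.pop('product_data', None)  # mirrors the cleanup in the next hunk
    print(json.dumps(result))         # {"status": "success"}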
@@ -676,7 +681,74 @@ class EnhancedBakeryMLTrainer:

         logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
         return training_results

+    async def _write_training_results_to_database(self,
+                                                  tenant_id: str,
+                                                  job_id: str,
+                                                  training_results: Dict[str, Any],
+                                                  repos: Dict) -> Dict[str, Any]:
+        """
+        Write training results to database sequentially to avoid concurrent session conflicts.
+
+        This method is called AFTER all parallel training is complete.
+        """
+        logger.info("Writing training results to database sequentially",
+                    total_products=len(training_results))
+
+        updated_results = {}
+
+        for product_id, result in training_results.items():
+            try:
+                if result.get('status') == 'success':
+                    model_info = result.get('model_info')
+                    product_data = result.get('product_data')
+
+                    if model_info and product_data is not None:
+                        # Create model record
+                        model_record = await self._create_model_record(
+                            repos, tenant_id, product_id, model_info, job_id, product_data
+                        )
+
+                        # Create performance metrics
+                        if model_info.get('training_metrics') and model_record:
+                            await self._create_performance_metrics(
+                                repos, model_record.id,
+                                tenant_id, product_id, model_info['training_metrics']
+                            )
+
+                        # Update result with model_record_id
+                        result['model_record_id'] = str(model_record.id) if model_record else None
+
+                        logger.info("Database records created successfully",
+                                    inventory_product_id=product_id,
+                                    model_record_id=model_record.id if model_record else None)
+
+                # Remove product_data from result to avoid serialization issues
+                if 'product_data' in result:
+                    del result['product_data']
+                if 'product_category' in result:
+                    del result['product_category']
+
+                updated_results[product_id] = result
+
+            except Exception as e:
+                logger.error("Failed to write database records for product",
+                             inventory_product_id=product_id,
+                             error=str(e))
+                # Keep the training result but mark that DB write failed
+                result['db_write_error'] = str(e)
+                if 'product_data' in result:
+                    del result['product_data']
+                if 'product_category' in result:
+                    del result['product_category']
+                updated_results[product_id] = result
+
+        logger.info("Database writes completed",
+                    successful_writes=len([r for r in updated_results.values() if 'model_record_id' in r]),
+                    total_products=len(updated_results))
+
+        return updated_results
+
     async def _create_model_record(self,
                                    repos: Dict,
                                    tenant_id: str,
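Note the error-handling choice above: a failed DB write does not discard the training result, it just tags it with db_write_error. A hypothetical consumer can therefore split persisted models from writes that need a retry:

    from typing import Any, Dict, Tuple

    def split_results(results: Dict[str, Dict[str, Any]]) -> Tuple[Dict, Dict]:
        # Products whose model record reached the database.
        persisted = {p: r for p, r in results.items() if r.get('model_record_id')}
        # Products that trained successfully but whose DB write failed.
        failed = {p: r for p, r in results.items() if 'db_write_error' in r}
        return persisted, failed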