Fix training job concurrent database session conflicts

Root Cause:
- Multiple parallel training tasks (3 at a time) were sharing the same database session
- This caused SQLAlchemy session state conflicts: "Session is already flushing" and "rollback() is already in progress"
- Additionally, duplicate model records were being created by both trainer and training_service

Fixes:
1. Separated model training from database writes:
   - Training happens in parallel (CPU-intensive)
   - Database writes happen sequentially after training completes
   - This eliminates concurrent session access

2. Removed duplicate database writes:
   - Trainer now writes all model records sequentially after parallel training
   - Training service now retrieves models instead of creating duplicates
   - Performance metrics are also created by trainer (no duplicates)

3. Added proper data flow:
   - _train_single_product: Only trains models, stores results
   - _write_training_results_to_database: Sequential DB writes after training
   - _store_trained_models: Changed to retrieve existing models
   - _create_performance_metrics: Changed to verify existing metrics

Benefits:
- Eliminates database session conflicts
- Prevents duplicate model records
- Maintains parallel training performance
- Ensures data consistency

Files Modified:
- services/training/app/ml/trainer.py
- services/training/app/services/training_service.py

Resolves: Onboarding training job database session conflicts
This commit is contained in:
Claude
2025-11-05 12:41:42 +00:00
parent 394ad3aea4
commit 799e7dbaeb
2 changed files with 157 additions and 111 deletions

View File

@@ -217,10 +217,17 @@ class EnhancedBakeryMLTrainer:
total_products=len(processed_data)
)
# Train all models in parallel (without DB writes to avoid session conflicts)
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
)
# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
tenant_id, job_id, training_results, repos
)
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
@@ -482,7 +489,12 @@ class EnhancedBakeryMLTrainer:
repos: Dict,
progress_tracker: ParallelProductProgressTracker,
product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
"""Train a single product model - used for parallel execution with progress aggregation"""
"""
Train a single product model - used for parallel execution with progress aggregation.
Note: This method ONLY trains the model and collects results. Database writes happen
separately to avoid concurrent session conflicts.
"""
product_start_time = time.time()
try:
@@ -497,7 +509,9 @@ class EnhancedBakeryMLTrainer:
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}',
'product_data': product_data, # Store for later DB writes
'product_category': product_category
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
@@ -557,30 +571,21 @@ class EnhancedBakeryMLTrainer:
continue
model_info['training_metrics'] = filtered_metrics
# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# IMPORTANT: Do NOT write to database here - causes concurrent session conflicts
# Store all info needed for later DB writes (done sequentially after all training completes)
result = {
'status': 'success',
'model_info': model_info,
'model_record_id': str(model_record.id) if model_record else None,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat()
'trained_at': datetime.now().isoformat(),
# Store data needed for DB writes later
'product_data': product_data,
'product_category': product_category
}
logger.info("Successfully trained model",
inventory_product_id=inventory_product_id,
model_record_id=model_record.id if model_record else None)
logger.info("Successfully trained model (DB writes deferred)",
inventory_product_id=inventory_product_id)
# Report completion to progress tracker (emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
@@ -676,7 +681,74 @@ class EnhancedBakeryMLTrainer:
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
return training_results
async def _write_training_results_to_database(self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
repos: Dict) -> Dict[str, Any]:
"""
Write training results to database sequentially to avoid concurrent session conflicts.
This method is called AFTER all parallel training is complete.
"""
logger.info("Writing training results to database sequentially",
total_products=len(training_results))
updated_results = {}
for product_id, result in training_results.items():
try:
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record
model_record = await self._create_model_record(
repos, tenant_id, product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
logger.info("Database records created successfully",
inventory_product_id=product_id,
model_record_id=model_record.id if model_record else None)
# Remove product_data from result to avoid serialization issues
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
except Exception as e:
logger.error("Failed to write database records for product",
inventory_product_id=product_id,
error=str(e))
# Keep the training result but mark that DB write failed
result['db_write_error'] = str(e)
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
logger.info("Database writes completed",
successful_writes=len([r for r in updated_results.values() if 'model_record_id' in r]),
total_products=len(updated_results))
return updated_results
async def _create_model_record(self,
repos: Dict,
tenant_id: str,