Merge pull request #1 from ualsweb/claude/fix-onboarding-training-job-011CUpkdtoMWGH7ANd33zRbm

Fix training job concurrent database session conflicts
This commit is contained in:
ualsweb
2025-11-05 13:44:07 +01:00
committed by GitHub
2 changed files with 157 additions and 111 deletions

View File

@@ -217,10 +217,17 @@ class EnhancedBakeryMLTrainer:
total_products=len(processed_data)
)
# Train all models in parallel (without DB writes to avoid session conflicts)
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
)
# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
tenant_id, job_id, training_results, repos
)
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
@@ -482,7 +489,12 @@ class EnhancedBakeryMLTrainer:
repos: Dict,
progress_tracker: ParallelProductProgressTracker,
product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
"""Train a single product model - used for parallel execution with progress aggregation"""
"""
Train a single product model - used for parallel execution with progress aggregation.
Note: This method ONLY trains the model and collects results. Database writes happen
separately to avoid concurrent session conflicts.
"""
product_start_time = time.time()
try:
@@ -497,7 +509,9 @@ class EnhancedBakeryMLTrainer:
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}',
'product_data': product_data, # Store for later DB writes
'product_category': product_category
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
@@ -557,30 +571,21 @@ class EnhancedBakeryMLTrainer:
continue
model_info['training_metrics'] = filtered_metrics
# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# IMPORTANT: Do NOT write to database here - causes concurrent session conflicts
# Store all info needed for later DB writes (done sequentially after all training completes)
result = {
'status': 'success',
'model_info': model_info,
'model_record_id': str(model_record.id) if model_record else None,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat()
'trained_at': datetime.now().isoformat(),
# Store data needed for DB writes later
'product_data': product_data,
'product_category': product_category
}
logger.info("Successfully trained model",
inventory_product_id=inventory_product_id,
model_record_id=model_record.id if model_record else None)
logger.info("Successfully trained model (DB writes deferred)",
inventory_product_id=inventory_product_id)
# Report completion to progress tracker (emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
@@ -676,7 +681,74 @@ class EnhancedBakeryMLTrainer:
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
return training_results
async def _write_training_results_to_database(self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
repos: Dict) -> Dict[str, Any]:
"""
Write training results to database sequentially to avoid concurrent session conflicts.
This method is called AFTER all parallel training is complete.
"""
logger.info("Writing training results to database sequentially",
total_products=len(training_results))
updated_results = {}
for product_id, result in training_results.items():
try:
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record
model_record = await self._create_model_record(
repos, tenant_id, product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
logger.info("Database records created successfully",
inventory_product_id=product_id,
model_record_id=model_record.id if model_record else None)
# Remove product_data from result to avoid serialization issues
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
except Exception as e:
logger.error("Failed to write database records for product",
inventory_product_id=product_id,
error=str(e))
# Keep the training result but mark that DB write failed
result['db_write_error'] = str(e)
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
logger.info("Database writes completed",
successful_writes=len([r for r in updated_results.values() if 'model_record_id' in r]),
total_products=len(updated_results))
return updated_results
async def _create_model_record(self,
repos: Dict,
tenant_id: str,

View File

@@ -342,81 +342,57 @@ class EnhancedTrainingService:
job_id: str,
training_results: Dict[str, Any]
) -> List:
"""Store trained models using repository pattern"""
"""
Retrieve or verify stored models from training results.
NOTE: Model records are now created by the trainer during parallel execution.
This method retrieves the already-created models instead of creating duplicates.
"""
stored_models = []
try:
# Get models_trained before sanitization to preserve structure
models_trained = training_results.get("models_trained", {})
logger.debug("Models trained structure",
models_trained_type=type(models_trained).__name__,
models_trained_keys=list(models_trained.keys()) if isinstance(models_trained, dict) else "not_dict")
for inventory_product_id, model_result in models_trained.items():
# Defensive check: ensure model_result is a dictionary
if not isinstance(model_result, dict):
logger.warning("Skipping invalid model_result for product",
inventory_product_id=inventory_product_id,
model_result_type=type(model_result).__name__,
model_result_value=str(model_result)[:100])
continue
if model_result.get("status") == "completed":
# Sanitize individual fields that might contain UUID objects
metrics = model_result.get("metrics", {})
if not isinstance(metrics, dict):
logger.warning("Invalid metrics object, using empty dict",
inventory_product_id=inventory_product_id,
metrics_type=type(metrics).__name__)
metrics = {}
model_data = {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"job_id": job_id,
"model_type": "prophet_optimized",
"model_path": model_result.get("model_path"),
"metadata_path": model_result.get("metadata_path"),
"mape": make_json_serializable(metrics.get("mape")),
"mae": make_json_serializable(metrics.get("mae")),
"rmse": make_json_serializable(metrics.get("rmse")),
"r2_score": make_json_serializable(metrics.get("r2_score")),
"training_samples": make_json_serializable(model_result.get("data_points", 0)),
"hyperparameters": make_json_serializable(model_result.get("hyperparameters")),
"features_used": make_json_serializable(model_result.get("features_used")),
"is_active": True,
"is_production": True, # New models are production by default
"data_quality_score": make_json_serializable(model_result.get("data_quality_score"))
}
# Create model record
model = await self.model_repo.create_model(model_data)
stored_models.append(model)
# Create artifacts if present
if model_result.get("model_path"):
artifact_data = {
"model_id": str(model.id),
"tenant_id": tenant_id,
"artifact_type": "model_file",
"file_path": model_result["model_path"],
"storage_location": "local"
}
await self.artifact_repo.create_artifact(artifact_data)
if model_result.get("metadata_path"):
artifact_data = {
"model_id": str(model.id),
"tenant_id": tenant_id,
"artifact_type": "metadata",
"file_path": model_result["metadata_path"],
"storage_location": "local"
}
await self.artifact_repo.create_artifact(artifact_data)
# Check if models were already created by the trainer (new approach)
# The trainer now writes models sequentially after parallel training
training_results_dict = training_results.get("training_results", {})
# Get list of successfully trained products
successful_products = [
product_id for product_id, result in training_results_dict.items()
if result.get('status') == 'success' and result.get('model_record_id')
]
logger.info("Retrieving models created by trainer",
successful_products=len(successful_products),
job_id=job_id)
# Retrieve the models that were already created by the trainer
for product_id in successful_products:
result = training_results_dict[product_id]
model_record_id = result.get('model_record_id')
if model_record_id:
try:
# Get the model from the database using base repository method
model = await self.model_repo.get_by_id(model_record_id)
if model:
stored_models.append(model)
logger.debug("Retrieved model from database",
model_id=model_record_id,
inventory_product_id=product_id)
except Exception as e:
logger.warning("Could not retrieve model record",
model_id=model_record_id,
inventory_product_id=product_id,
error=str(e))
logger.info("Models retrieval complete",
models_retrieved=len(stored_models),
expected=len(successful_products))
return stored_models
except Exception as e:
logger.error("Failed to store trained models",
logger.error("Failed to retrieve stored models",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
@@ -428,30 +404,28 @@ class EnhancedTrainingService:
stored_models: List,
training_results: Dict[str, Any]
):
"""Create performance metrics for stored models"""
"""
Verify performance metrics for stored models.
NOTE: Performance metrics are now created by the trainer during model creation.
This method now just verifies they exist rather than creating duplicates.
"""
try:
logger.info("Verifying performance metrics",
models_count=len(stored_models))
# Performance metrics are already created by the trainer
# This method is kept for compatibility but doesn't create duplicates
for model in stored_models:
model_result = training_results.get("models_trained", {}).get(str(model.inventory_product_id))
if model_result and model_result.get("metrics"):
metrics = model_result["metrics"]
logger.debug("Performance metrics already created for model",
model_id=str(model.id),
inventory_product_id=str(model.inventory_product_id))
metric_data = {
"model_id": str(model.id),
"tenant_id": tenant_id,
"inventory_product_id": str(model.inventory_product_id),
"mae": metrics.get("mae"),
"mse": metrics.get("mse"),
"rmse": metrics.get("rmse"),
"mape": metrics.get("mape"),
"r2_score": metrics.get("r2_score"),
"accuracy_percentage": metrics.get("accuracy_percentage", 100 - metrics.get("mape", 0)),
"evaluation_samples": model.training_samples
}
await self.performance_repo.create_performance_metric(metric_data)
logger.info("Performance metrics verification complete",
models_count=len(stored_models))
except Exception as e:
logger.error("Failed to create performance metrics",
logger.error("Failed to verify performance metrics",
tenant_id=tenant_id,
error=str(e))