Fix training job concurrent database session conflicts
Root Cause: - Multiple parallel training tasks (3 at a time) were sharing the same database session - This caused SQLAlchemy session state conflicts: "Session is already flushing" and "rollback() is already in progress" - Additionally, duplicate model records were being created by both trainer and training_service Fixes: 1. Separated model training from database writes: - Training happens in parallel (CPU-intensive) - Database writes happen sequentially after training completes - This eliminates concurrent session access 2. Removed duplicate database writes: - Trainer now writes all model records sequentially after parallel training - Training service now retrieves models instead of creating duplicates - Performance metrics are also created by trainer (no duplicates) 3. Added proper data flow: - _train_single_product: Only trains models, stores results - _write_training_results_to_database: Sequential DB writes after training - _store_trained_models: Changed to retrieve existing models - _create_performance_metrics: Changed to verify existing metrics Benefits: - Eliminates database session conflicts - Prevents duplicate model records - Maintains parallel training performance - Ensures data consistency Files Modified: - services/training/app/ml/trainer.py - services/training/app/services/training_service.py Resolves: Onboarding training job database session conflicts
This commit is contained in:
@@ -342,81 +342,57 @@ class EnhancedTrainingService:
|
||||
job_id: str,
|
||||
training_results: Dict[str, Any]
|
||||
) -> List:
|
||||
"""Store trained models using repository pattern"""
|
||||
"""
|
||||
Retrieve or verify stored models from training results.
|
||||
|
||||
NOTE: Model records are now created by the trainer during parallel execution.
|
||||
This method retrieves the already-created models instead of creating duplicates.
|
||||
"""
|
||||
stored_models = []
|
||||
|
||||
|
||||
try:
|
||||
# Get models_trained before sanitization to preserve structure
|
||||
models_trained = training_results.get("models_trained", {})
|
||||
logger.debug("Models trained structure",
|
||||
models_trained_type=type(models_trained).__name__,
|
||||
models_trained_keys=list(models_trained.keys()) if isinstance(models_trained, dict) else "not_dict")
|
||||
|
||||
for inventory_product_id, model_result in models_trained.items():
|
||||
# Defensive check: ensure model_result is a dictionary
|
||||
if not isinstance(model_result, dict):
|
||||
logger.warning("Skipping invalid model_result for product",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_result_type=type(model_result).__name__,
|
||||
model_result_value=str(model_result)[:100])
|
||||
continue
|
||||
|
||||
if model_result.get("status") == "completed":
|
||||
# Sanitize individual fields that might contain UUID objects
|
||||
metrics = model_result.get("metrics", {})
|
||||
if not isinstance(metrics, dict):
|
||||
logger.warning("Invalid metrics object, using empty dict",
|
||||
inventory_product_id=inventory_product_id,
|
||||
metrics_type=type(metrics).__name__)
|
||||
metrics = {}
|
||||
model_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"job_id": job_id,
|
||||
"model_type": "prophet_optimized",
|
||||
"model_path": model_result.get("model_path"),
|
||||
"metadata_path": model_result.get("metadata_path"),
|
||||
"mape": make_json_serializable(metrics.get("mape")),
|
||||
"mae": make_json_serializable(metrics.get("mae")),
|
||||
"rmse": make_json_serializable(metrics.get("rmse")),
|
||||
"r2_score": make_json_serializable(metrics.get("r2_score")),
|
||||
"training_samples": make_json_serializable(model_result.get("data_points", 0)),
|
||||
"hyperparameters": make_json_serializable(model_result.get("hyperparameters")),
|
||||
"features_used": make_json_serializable(model_result.get("features_used")),
|
||||
"is_active": True,
|
||||
"is_production": True, # New models are production by default
|
||||
"data_quality_score": make_json_serializable(model_result.get("data_quality_score"))
|
||||
}
|
||||
|
||||
# Create model record
|
||||
model = await self.model_repo.create_model(model_data)
|
||||
stored_models.append(model)
|
||||
|
||||
# Create artifacts if present
|
||||
if model_result.get("model_path"):
|
||||
artifact_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"artifact_type": "model_file",
|
||||
"file_path": model_result["model_path"],
|
||||
"storage_location": "local"
|
||||
}
|
||||
await self.artifact_repo.create_artifact(artifact_data)
|
||||
|
||||
if model_result.get("metadata_path"):
|
||||
artifact_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"artifact_type": "metadata",
|
||||
"file_path": model_result["metadata_path"],
|
||||
"storage_location": "local"
|
||||
}
|
||||
await self.artifact_repo.create_artifact(artifact_data)
|
||||
|
||||
# Check if models were already created by the trainer (new approach)
|
||||
# The trainer now writes models sequentially after parallel training
|
||||
training_results_dict = training_results.get("training_results", {})
|
||||
|
||||
# Get list of successfully trained products
|
||||
successful_products = [
|
||||
product_id for product_id, result in training_results_dict.items()
|
||||
if result.get('status') == 'success' and result.get('model_record_id')
|
||||
]
|
||||
|
||||
logger.info("Retrieving models created by trainer",
|
||||
successful_products=len(successful_products),
|
||||
job_id=job_id)
|
||||
|
||||
# Retrieve the models that were already created by the trainer
|
||||
for product_id in successful_products:
|
||||
result = training_results_dict[product_id]
|
||||
model_record_id = result.get('model_record_id')
|
||||
|
||||
if model_record_id:
|
||||
try:
|
||||
# Get the model from the database using base repository method
|
||||
model = await self.model_repo.get_by_id(model_record_id)
|
||||
if model:
|
||||
stored_models.append(model)
|
||||
logger.debug("Retrieved model from database",
|
||||
model_id=model_record_id,
|
||||
inventory_product_id=product_id)
|
||||
except Exception as e:
|
||||
logger.warning("Could not retrieve model record",
|
||||
model_id=model_record_id,
|
||||
inventory_product_id=product_id,
|
||||
error=str(e))
|
||||
|
||||
logger.info("Models retrieval complete",
|
||||
models_retrieved=len(stored_models),
|
||||
expected=len(successful_products))
|
||||
|
||||
return stored_models
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to store trained models",
|
||||
logger.error("Failed to retrieve stored models",
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
@@ -428,30 +404,28 @@ class EnhancedTrainingService:
|
||||
stored_models: List,
|
||||
training_results: Dict[str, Any]
|
||||
):
|
||||
"""Create performance metrics for stored models"""
|
||||
"""
|
||||
Verify performance metrics for stored models.
|
||||
|
||||
NOTE: Performance metrics are now created by the trainer during model creation.
|
||||
This method now just verifies they exist rather than creating duplicates.
|
||||
"""
|
||||
try:
|
||||
logger.info("Verifying performance metrics",
|
||||
models_count=len(stored_models))
|
||||
|
||||
# Performance metrics are already created by the trainer
|
||||
# This method is kept for compatibility but doesn't create duplicates
|
||||
for model in stored_models:
|
||||
model_result = training_results.get("models_trained", {}).get(str(model.inventory_product_id))
|
||||
if model_result and model_result.get("metrics"):
|
||||
metrics = model_result["metrics"]
|
||||
logger.debug("Performance metrics already created for model",
|
||||
model_id=str(model.id),
|
||||
inventory_product_id=str(model.inventory_product_id))
|
||||
|
||||
metric_data = {
|
||||
"model_id": str(model.id),
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": str(model.inventory_product_id),
|
||||
"mae": metrics.get("mae"),
|
||||
"mse": metrics.get("mse"),
|
||||
"rmse": metrics.get("rmse"),
|
||||
"mape": metrics.get("mape"),
|
||||
"r2_score": metrics.get("r2_score"),
|
||||
"accuracy_percentage": metrics.get("accuracy_percentage", 100 - metrics.get("mape", 0)),
|
||||
"evaluation_samples": model.training_samples
|
||||
}
|
||||
|
||||
await self.performance_repo.create_performance_metric(metric_data)
|
||||
logger.info("Performance metrics verification complete",
|
||||
models_count=len(stored_models))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create performance metrics",
|
||||
logger.error("Failed to verify performance metrics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user