Fix deadlock issues in training

This commit is contained in:
Urtzi Alfaro
2025-11-05 18:47:20 +01:00
parent fd0a96e254
commit 74215d3e85
3 changed files with 620 additions and 131 deletions
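The core of the fix is an optional session parameter threaded through the trainer: each entry point reuses the caller's database session when one is provided and only opens its own when called standalone, so a session is never acquired inside the scope of another. A minimal sketch of that pattern, with simplified class and method names (only database_manager.get_session() is taken from the diff):

from typing import Any, Dict, Optional

class Trainer:
    def __init__(self, database_manager):
        self.database_manager = database_manager

    async def train(self, tenant_id: str, job_id: str,
                    session: Optional[Any] = None) -> Dict[str, Any]:
        # Reuse the caller's session when one is provided; otherwise own one here.
        if session is None:
            async with self.database_manager.get_session() as db_session:
                return await self._execute(tenant_id, job_id, db_session)
        # The caller manages commit/rollback, so no context manager around its session.
        return await self._execute(tenant_id, job_id, session)

    async def _execute(self, tenant_id: str, job_id: str, session) -> Dict[str, Any]:
        ...  # every downstream helper receives this session instead of opening a new one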


@@ -77,6 +77,7 @@ class EnhancedBakeryMLTrainer:
tenant_id: Tenant identifier
training_dataset: Prepared training dataset with aligned dates
job_id: Training job identifier
session: Database session to use (if None, creates one)
Returns:
Dictionary with training results for each product
@@ -89,68 +90,20 @@ class EnhancedBakeryMLTrainer:
tenant_id=tenant_id)
try:
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products from the sales data
products = sales_df['inventory_product_id'].unique().tolist()
# Debug: Log sales data details to understand why only one product is found
total_sales_records = len(sales_df)
sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
logger.info("Enhanced training pipeline - Sales data analysis",
total_sales_records=total_sales_records,
products_count=len(products),
products=products,
sales_by_product=sales_by_product)
if len(products) == 1:
logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
tenant_id=tenant_id,
single_product_id=products[0],
total_sales_records=total_sales_records)
elif len(products) == 0:
raise ValueError("No products found in sales data")
else:
logger.info("Multiple products detected for training",
products_count=len(products))
# Event 1: Training Started (0%) - update with actual product count AND time estimates
# Calculate accurate time estimates now that we know the actual product count
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
# Try to get historical average for more accurate estimates
try:
historical_avg = await get_historical_average_estimate(
db_session,
tenant_id
# Use the provided session if one was given; otherwise create a new one, to prevent nested sessions and deadlocks
should_create_session = session is None
if should_create_session:
async with self.database_manager.get_session() as db_session:
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, db_session
)
avg_time_per_product = historical_avg if historical_avg else 60.0
logger.info("Using historical average for time estimation",
avg_time_per_product=avg_time_per_product,
has_historical_data=historical_avg is not None)
except Exception as e:
logger.warning("Could not get historical average, using default",
error=str(e))
avg_time_per_product = 60.0
estimated_duration_minutes = calculate_initial_estimate(
total_products=len(products),
avg_training_time_per_product=avg_time_per_product
else:
# Use the provided session (no context manager needed since caller manages it)
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, session
)
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
@@ -177,7 +130,7 @@ class EnhancedBakeryMLTrainer:
# Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id
sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
)
# Categorize all products for category-specific forecasting
@@ -301,12 +254,220 @@ class EnhancedBakeryMLTrainer:
await publish_training_failed(job_id, tenant_id, str(e))
raise
async def _execute_training_pipeline(self, tenant_id: str, training_dataset: TrainingDataSet,
job_id: str, session) -> Dict[str, Any]:
"""
Execute the training pipeline with the given session.
This is extracted to avoid code duplication when handling provided vs. created sessions.
"""
# Get repositories with the session
repos = await self._get_repositories(session)
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products from the sales data
products = sales_df['inventory_product_id'].unique().tolist()
# Debug: Log sales data details to understand why only one product is found
total_sales_records = len(sales_df)
sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
logger.info("Enhanced training pipeline - Sales data analysis",
total_sales_records=total_sales_records,
products_count=len(products),
products=products,
sales_by_product=sales_by_product)
if len(products) == 1:
logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
tenant_id=tenant_id,
single_product_id=products[0],
total_sales_records=total_sales_records)
elif len(products) == 0:
raise ValueError("No products found in sales data")
else:
logger.info("Multiple products detected for training",
products_count=len(products))
# Event 1: Training Started (0%) - update with actual product count AND time estimates
# Calculate accurate time estimates now that we know the actual product count
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
# Try to get historical average for more accurate estimates
try:
historical_avg = await get_historical_average_estimate(
session,
tenant_id
)
avg_time_per_product = historical_avg if historical_avg else 60.0
logger.info("Using historical average for time estimation",
avg_time_per_product=avg_time_per_product,
has_historical_data=historical_avg is not None)
except Exception as e:
logger.warning("Could not get historical average, using default",
error=str(e))
avg_time_per_product = 60.0
estimated_duration_minutes = calculate_initial_estimate(
total_products=len(products),
avg_training_time_per_product=avg_time_per_product
)
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Note: Initial event was already published by API endpoint with estimated product count,
# this updates with real count and recalculated time estimates based on actual data
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(products),
estimated_duration_minutes=estimated_duration_minutes,
estimated_completion_time=estimated_completion_time.isoformat()
)
# Create initial training log entry
await repos['training_log'].update_log_progress(
job_id, 5, "data_processing", "running"
)
# ✅ FIX: Flush the session so the progress update is written to the database before proceeding.
# This prevents deadlocks when training methods later need to acquire locks.
await session.flush()
logger.debug("Flushed session after initial progress update")
# Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
)
# Categorize all products for category-specific forecasting
logger.info("Categorizing products for optimized forecasting")
product_categories = await self._categorize_all_products(
sales_df, processed_data
)
logger.info("Product categorization complete",
total_products=len(product_categories),
categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
for cat in set(product_categories.values())})
# Event 2: Data Analysis (20%)
# Recalculate time remaining based on elapsed time
start_time = await repos['training_log'].get_start_time(job_id)
elapsed_seconds = 0
if start_time:
elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
# Estimate remaining time: we've done ~20% of work (data analysis)
# Remaining 80% includes training all products
products_to_train = len(processed_data)
estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
# Recalculate estimated completion time
estimated_completion_time_data_analysis = calculate_estimated_completion_time(
estimated_remaining_seconds / 60
)
await publish_data_analysis(
job_id,
tenant_id,
f"Data analysis completed for {len(processed_data)} products",
estimated_time_remaining_seconds=estimated_remaining_seconds,
estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
)
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")
# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(processed_data)
)
# Train all models in parallel (without DB writes to avoid session conflicts)
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, session
)
# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
tenant_id, job_id, training_results, repos
)
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
)
# Calculate successful and failed trainings
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])
# Event 4: Training Completed (100%)
await publish_training_completed(
job_id,
tenant_id,
successful_trainings,
failed_trainings,
total_duration
)
# Create comprehensive result with repository data
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
"total_products": len(products),
"training_results": training_results,
"enhanced_summary": summary,
"models_trained": summary.get('models_created', {}),
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"repository_metadata": {
"total_records_created": summary.get('total_db_records', 0),
"performance_metrics_stored": summary.get('performance_metrics_created', 0),
"artifacts_created": summary.get('artifacts_created', 0)
},
"completed_at": datetime.now().isoformat()
}
logger.info("Enhanced ML training pipeline completed successfully",
job_id=job_id,
models_created=len([r for r in training_results.values() if r.get('status') == 'success']))
return result
async def train_single_product_model(self,
tenant_id: str,
inventory_product_id: str,
training_data: pd.DataFrame,
job_id: str = None) -> Dict[str, Any]:
job_id: str = None,
session=None) -> Dict[str, Any]:
"""
Train a model for a single product using repository pattern.
@@ -315,6 +476,7 @@ class EnhancedBakeryMLTrainer:
inventory_product_id: Specific inventory product to train
training_data: Prepared training DataFrame for the product
job_id: Training job identifier (optional)
session: Database session to use (if None, creates one)
Returns:
Dictionary with model training results
@@ -329,9 +491,14 @@ class EnhancedBakeryMLTrainer:
data_points=len(training_data))
try:
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Use the provided session if one was given; otherwise create a new one, to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session
if should_create_session:
# Only create a session if one wasn't provided
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
@@ -367,7 +534,8 @@ class EnhancedBakeryMLTrainer:
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker
progress_tracker=progress_tracker,
session=db_session # Pass the session to prevent nested sessions
)
logger.info("Single product training completed",
@@ -391,7 +559,7 @@ class EnhancedBakeryMLTrainer:
continue
# Return appropriate result format
return {
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
@@ -403,6 +571,8 @@ class EnhancedBakeryMLTrainer:
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
return result_dict
except Exception as e:
logger.error("Single product model training failed",
job_id=job_id,
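The hunk above also keeps the split between parallel model fitting and sequential persistence: models are trained concurrently without touching the database, and all results are then written one at a time through the shared session. A rough illustration of that shape (the helper names _fit_model and _persist_result are hypothetical stand-ins, not functions from the diff):

import asyncio
from typing import Any, Dict, List

async def _fit_model(product_id: str) -> Dict[str, Any]:
    ...  # pure compute: no database access in the parallel phase

async def _persist_result(session, product_id: str, result: Dict[str, Any]) -> None:
    ...  # single writer: all DB work goes through the one shared session

async def train_and_persist(products: List[str], session) -> None:
    # Fit all models concurrently; nothing here uses the session, so the
    # shared session is never used from two coroutines at once.
    results = await asyncio.gather(*(_fit_model(p) for p in products),
                                   return_exceptions=True)
    # Persist sequentially afterwards, on the caller-provided session.
    for product_id, result in zip(products, results):
        if not isinstance(result, Exception):
            await _persist_result(session, product_id, result)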
@@ -452,7 +622,8 @@ class EnhancedBakeryMLTrainer:
traffic_df: pd.DataFrame,
products: List[str],
tenant_id: str,
job_id: str) -> Dict[str, pd.DataFrame]:
job_id: str,
session=None) -> Dict[str, pd.DataFrame]:
"""Process data for all products using enhanced processor with repository tracking"""
processed_data = {}
@@ -470,13 +641,15 @@ class EnhancedBakeryMLTrainer:
continue
# Use enhanced data processor with repository tracking
# Pass the session to prevent nested session issues
processed_product_data = await self.enhanced_data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
tenant_id=tenant_id,
job_id=job_id
job_id=job_id,
session=session # Pass the session to avoid creating new ones
)
processed_data[inventory_product_id] = processed_product_data
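Taken together, the changes let a caller own a single session for the whole training run and hand it down through _execute_training_pipeline, _process_all_products_enhanced, and prepare_training_data. A hedged usage sketch follows; the orchestrating function and the top-level train_models name are assumptions, while the keyword arguments mirror the signatures shown in the diff:

async def run_training_job(trainer, database_manager, tenant_id: str,
                           training_dataset, job_id: str):
    # One session for the whole job: the trainer and every helper it calls
    # reuse it instead of opening nested sessions of their own.
    async with database_manager.get_session() as session:
        return await trainer.train_models(  # method name assumed
            tenant_id=tenant_id,
            training_dataset=training_dataset,
            job_id=job_id,
            session=session,
        )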