Fix deadlock issues in training
@@ -77,6 +77,7 @@ class EnhancedBakeryMLTrainer:
            tenant_id: Tenant identifier
            training_dataset: Prepared training dataset with aligned dates
            job_id: Training job identifier
            session: Database session to use (if None, creates one)

        Returns:
            Dictionary with training results for each product
@@ -89,68 +90,20 @@ class EnhancedBakeryMLTrainer:
                    tenant_id=tenant_id)

        try:
            # Get database session and repositories
            async with self.database_manager.get_session() as db_session:
                repos = await self._get_repositories(db_session)

                # Convert sales data to DataFrame
                sales_df = pd.DataFrame(training_dataset.sales_data)
                weather_df = pd.DataFrame(training_dataset.weather_data)
                traffic_df = pd.DataFrame(training_dataset.traffic_data)

                # Validate input data
                await self._validate_input_data(sales_df, tenant_id)

                # Get unique products from the sales data
                products = sales_df['inventory_product_id'].unique().tolist()

                # Debug: Log sales data details to understand why only one product is found
                total_sales_records = len(sales_df)
                sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()

                logger.info("Enhanced training pipeline - Sales data analysis",
                            total_sales_records=total_sales_records,
                            products_count=len(products),
                            products=products,
                            sales_by_product=sales_by_product)

                if len(products) == 1:
                    logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
                                   tenant_id=tenant_id,
                                   single_product_id=products[0],
                                   total_sales_records=total_sales_records)
                elif len(products) == 0:
                    raise ValueError("No products found in sales data")
                else:
                    logger.info("Multiple products detected for training",
                                products_count=len(products))

                # Event 1: Training Started (0%) - update with actual product count AND time estimates
                # Calculate accurate time estimates now that we know the actual product count
                from app.utils.time_estimation import (
                    calculate_initial_estimate,
                    calculate_estimated_completion_time,
                    get_historical_average_estimate
                )

                # Try to get historical average for more accurate estimates
                try:
                    historical_avg = await get_historical_average_estimate(
                        db_session,
                        tenant_id
        # Use provided session or create new one to prevent nested sessions and deadlocks
        should_create_session = session is None
        db_session = session if session is not None else None

        # Use the provided session or create a new one if needed
        if should_create_session:
            async with self.database_manager.get_session() as db_session:
                return await self._execute_training_pipeline(
                    tenant_id, training_dataset, job_id, db_session
                )
                    avg_time_per_product = historical_avg if historical_avg else 60.0
                    logger.info("Using historical average for time estimation",
                                avg_time_per_product=avg_time_per_product,
                                has_historical_data=historical_avg is not None)
                except Exception as e:
                    logger.warning("Could not get historical average, using default",
                                   error=str(e))
                    avg_time_per_product = 60.0

                estimated_duration_minutes = calculate_initial_estimate(
                    total_products=len(products),
                    avg_training_time_per_product=avg_time_per_product
        else:
            # Use the provided session (no context manager needed since caller manages it)
            return await self._execute_training_pipeline(
                tenant_id, training_dataset, job_id, session
            )
                estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

@@ -177,7 +130,7 @@ class EnhancedBakeryMLTrainer:
                # Process data for each product using enhanced processor
                logger.info("Processing data using enhanced processor")
                processed_data = await self._process_all_products_enhanced(
                    sales_df, weather_df, traffic_df, products, tenant_id, job_id
                    sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
                )

                # Categorize all products for category-specific forecasting
@@ -301,12 +254,220 @@ class EnhancedBakeryMLTrainer:
            await publish_training_failed(job_id, tenant_id, str(e))

            raise

    async def _execute_training_pipeline(self, tenant_id: str, training_dataset: TrainingDataSet,
                                         job_id: str, session) -> Dict[str, Any]:
        """
        Execute the training pipeline with the given session.
        This is extracted to avoid code duplication when handling provided vs. created sessions.
        """
        # Get repositories with the session
        repos = await self._get_repositories(session)

        # Convert sales data to DataFrame
        sales_df = pd.DataFrame(training_dataset.sales_data)
        weather_df = pd.DataFrame(training_dataset.weather_data)
        traffic_df = pd.DataFrame(training_dataset.traffic_data)

        # Validate input data
        await self._validate_input_data(sales_df, tenant_id)

        # Get unique products from the sales data
        products = sales_df['inventory_product_id'].unique().tolist()

        # Debug: Log sales data details to understand why only one product is found
        total_sales_records = len(sales_df)
        sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()

        logger.info("Enhanced training pipeline - Sales data analysis",
                    total_sales_records=total_sales_records,
                    products_count=len(products),
                    products=products,
                    sales_by_product=sales_by_product)

        if len(products) == 1:
            logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
                           tenant_id=tenant_id,
                           single_product_id=products[0],
                           total_sales_records=total_sales_records)
        elif len(products) == 0:
            raise ValueError("No products found in sales data")
        else:
            logger.info("Multiple products detected for training",
                        products_count=len(products))

        # Event 1: Training Started (0%) - update with actual product count AND time estimates
        # Calculate accurate time estimates now that we know the actual product count
        from app.utils.time_estimation import (
            calculate_initial_estimate,
            calculate_estimated_completion_time,
            get_historical_average_estimate
        )

        # Try to get historical average for more accurate estimates
        try:
            historical_avg = await get_historical_average_estimate(
                session,
                tenant_id
            )
            avg_time_per_product = historical_avg if historical_avg else 60.0
            logger.info("Using historical average for time estimation",
                        avg_time_per_product=avg_time_per_product,
                        has_historical_data=historical_avg is not None)
        except Exception as e:
            logger.warning("Could not get historical average, using default",
                           error=str(e))
            avg_time_per_product = 60.0

        estimated_duration_minutes = calculate_initial_estimate(
            total_products=len(products),
            avg_training_time_per_product=avg_time_per_product
        )
        estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

        # Note: Initial event was already published by API endpoint with estimated product count,
        # this updates with real count and recalculated time estimates based on actual data
        await publish_training_started(
            job_id=job_id,
            tenant_id=tenant_id,
            total_products=len(products),
            estimated_duration_minutes=estimated_duration_minutes,
            estimated_completion_time=estimated_completion_time.isoformat()
        )

        # Create initial training log entry
        await repos['training_log'].update_log_progress(
            job_id, 5, "data_processing", "running"
        )

        # ✅ FIX: Flush the session to ensure the update is committed before proceeding
        # This prevents deadlocks when training methods need to acquire locks
        await session.flush()
        logger.debug("Flushed session after initial progress update")

        # Process data for each product using enhanced processor
        logger.info("Processing data using enhanced processor")
        processed_data = await self._process_all_products_enhanced(
            sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
        )

        # Categorize all products for category-specific forecasting
        logger.info("Categorizing products for optimized forecasting")
        product_categories = await self._categorize_all_products(
            sales_df, processed_data
        )
        logger.info("Product categorization complete",
                    total_products=len(product_categories),
                    categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
                                          for cat in set(product_categories.values())})

        # Event 2: Data Analysis (20%)
        # Recalculate time remaining based on elapsed time
        start_time = await repos['training_log'].get_start_time(job_id)
        elapsed_seconds = 0
        if start_time:
            elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())

        # Estimate remaining time: we've done ~20% of work (data analysis)
        # Remaining 80% includes training all products
        products_to_train = len(processed_data)
        estimated_remaining_seconds = int(products_to_train * avg_time_per_product)

        # Recalculate estimated completion time
        estimated_completion_time_data_analysis = calculate_estimated_completion_time(
            estimated_remaining_seconds / 60
        )

        await publish_data_analysis(
            job_id,
            tenant_id,
            f"Data analysis completed for {len(processed_data)} products",
            estimated_time_remaining_seconds=estimated_remaining_seconds,
            estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
        )

        # Train models for each processed product with progress aggregation
        logger.info("Training models with repository integration and progress aggregation")

        # Create progress tracker for parallel product training (20-80%)
        progress_tracker = ParallelProductProgressTracker(
            job_id=job_id,
            tenant_id=tenant_id,
            total_products=len(processed_data)
        )

        # Train all models in parallel (without DB writes to avoid session conflicts)
        # ✅ FIX: Pass session to prevent nested session issues and deadlocks
        training_results = await self._train_all_models_enhanced(
            tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, session
        )

        # Write all training results to database sequentially (after parallel training completes)
        logger.info("Writing training results to database sequentially")
        training_results = await self._write_training_results_to_database(
            tenant_id, job_id, training_results, repos
        )

        # Calculate overall training summary with enhanced metrics
        summary = await self._calculate_enhanced_training_summary(
            training_results, repos, tenant_id
        )

        # Calculate successful and failed trainings
        successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
        failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
        total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])

        # Event 4: Training Completed (100%)
        await publish_training_completed(
            job_id,
            tenant_id,
            successful_trainings,
            failed_trainings,
            total_duration
        )

        # Create comprehensive result with repository data
        result = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "completed",
            "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
            "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
            "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
            "total_products": len(products),
            "training_results": training_results,
            "enhanced_summary": summary,
            "models_trained": summary.get('models_created', {}),
            "data_info": {
                "date_range": {
                    "start": training_dataset.date_range.start.isoformat(),
                    "end": training_dataset.date_range.end.isoformat(),
                    "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
                },
                "data_sources": [source.value for source in training_dataset.date_range.available_sources],
                "constraints_applied": training_dataset.date_range.constraints
            },
            "repository_metadata": {
                "total_records_created": summary.get('total_db_records', 0),
                "performance_metrics_stored": summary.get('performance_metrics_created', 0),
                "artifacts_created": summary.get('artifacts_created', 0)
            },
            "completed_at": datetime.now().isoformat()
        }

        logger.info("Enhanced ML training pipeline completed successfully",
                    job_id=job_id,
                    models_created=len([r for r in training_results.values() if r.get('status') == 'success']))

        return result

    async def train_single_product_model(self,
                                         tenant_id: str,
                                         inventory_product_id: str,
                                         training_data: pd.DataFrame,
                                         job_id: str = None) -> Dict[str, Any]:
                                         job_id: str = None,
                                         session=None) -> Dict[str, Any]:
        """
        Train a model for a single product using repository pattern.
@@ -315,6 +476,7 @@ class EnhancedBakeryMLTrainer:
            inventory_product_id: Specific inventory product to train
            training_data: Prepared training DataFrame for the product
            job_id: Training job identifier (optional)
            session: Database session to use (if None, creates one)

        Returns:
            Dictionary with model training results
@@ -329,9 +491,14 @@ class EnhancedBakeryMLTrainer:
                    data_points=len(training_data))

        try:
            # Get database session and repositories
            async with self.database_manager.get_session() as db_session:
                repos = await self._get_repositories(db_session)
            # Use provided session or create new one to prevent nested sessions and deadlocks
            should_create_session = session is None
            db_session = session if session is not None else None

            if should_create_session:
                # Only create a session if one wasn't provided
                async with self.database_manager.get_session() as db_session:
                    repos = await self._get_repositories(db_session)

                    # Validate input data
                    if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
@@ -367,7 +534,8 @@ class EnhancedBakeryMLTrainer:
                product_data=training_data,
                job_id=job_id,
                repos=repos,
                progress_tracker=progress_tracker
                progress_tracker=progress_tracker,
                session=db_session  # Pass the session to prevent nested sessions
            )

            logger.info("Single product training completed",
@@ -391,7 +559,7 @@ class EnhancedBakeryMLTrainer:
                    continue

            # Return appropriate result format
            return {
            result_dict = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
@@ -403,6 +571,8 @@ class EnhancedBakeryMLTrainer:
                "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
            }

            return result_dict

        except Exception as e:
            logger.error("Single product model training failed",
                         job_id=job_id,
@@ -452,7 +622,8 @@ class EnhancedBakeryMLTrainer:
                                              traffic_df: pd.DataFrame,
                                              products: List[str],
                                              tenant_id: str,
                                              job_id: str) -> Dict[str, pd.DataFrame]:
                                              job_id: str,
                                              session=None) -> Dict[str, pd.DataFrame]:
        """Process data for all products using enhanced processor with repository tracking"""
        processed_data = {}

@@ -470,13 +641,15 @@ class EnhancedBakeryMLTrainer:
                continue

            # Use enhanced data processor with repository tracking
            # Pass the session to prevent nested session issues
            processed_product_data = await self.enhanced_data_processor.prepare_training_data(
                sales_data=product_sales,
                weather_data=weather_df,
                traffic_data=traffic_df,
                inventory_product_id=inventory_product_id,
                tenant_id=tenant_id,
                job_id=job_id
                job_id=job_id,
                session=session  # Pass the session to avoid creating new ones
            )

            processed_data[inventory_product_id] = processed_product_data
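
The change above boils down to one pattern: the training entry points accept an optional database session and open their own only when the caller did not pass one, so a second session (with its own locks) is never nested inside the caller's transaction, and pending progress writes are flushed before long-running work starts. Below is a minimal sketch of that pattern, assuming SQLAlchemy's async session API; `DatabaseManagerLike`, `TrainerSketch` and `_run` are illustrative placeholders, not this repository's actual classes or signatures.

```python
# Sketch only: session is injected by the caller, or created once at the top.
from typing import Any, AsyncContextManager, Dict, Optional, Protocol

from sqlalchemy.ext.asyncio import AsyncSession


class DatabaseManagerLike(Protocol):
    # Placeholder for whatever object hands out async session context managers.
    def get_session(self) -> AsyncContextManager[AsyncSession]: ...


class TrainerSketch:
    def __init__(self, database_manager: DatabaseManagerLike) -> None:
        self.database_manager = database_manager

    async def train(self, tenant_id: str, job_id: str,
                    session: Optional[AsyncSession] = None) -> Dict[str, Any]:
        # Reuse the caller's session when one is provided; open a new one only
        # otherwise. Opening a second session inside the caller's transaction
        # is the nesting the commit above removes.
        if session is None:
            async with self.database_manager.get_session() as own_session:
                return await self._run(tenant_id, job_id, own_session)
        # The caller owns commit/rollback/close for a session it passed in.
        return await self._run(tenant_id, job_id, session)

    async def _run(self, tenant_id: str, job_id: str,
                   session: AsyncSession) -> Dict[str, Any]:
        # ... write the initial progress update on `session` here ...
        # Flush so the pending progress write is emitted to the database before
        # the long-running training steps begin, mirroring the flush added in
        # the diff to avoid later lock acquisition stalling on unflushed work.
        await session.flush()
        return {"tenant_id": tenant_id, "job_id": job_id, "status": "completed"}
```

The same shape is why `_execute_training_pipeline`, `train_single_product_model`, `_process_all_products_enhanced` and `prepare_training_data` all gain a `session` parameter in this commit: each helper works on the single session it is handed instead of opening another one.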