Fix orchestrator issues
@@ -105,145 +105,7 @@ class EnhancedBakeryMLTrainer:
            return await self._execute_training_pipeline(
                tenant_id, training_dataset, job_id, session
            )
            estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

            # Note: the initial event was already published by the API endpoint with an estimated
            # product count; this update sends the real count and time estimates recalculated
            # from the actual data
            await publish_training_started(
                job_id=job_id,
                tenant_id=tenant_id,
                total_products=len(products),
                estimated_duration_minutes=estimated_duration_minutes,
                estimated_completion_time=estimated_completion_time.isoformat()
            )

            # Create initial training log entry
            await repos['training_log'].update_log_progress(
                job_id, 5, "data_processing", "running"
            )

            # ✅ FIX: Flush the session so the progress update is written out before proceeding.
            # This prevents deadlocks when training methods later need to acquire locks.
            await db_session.flush()
            logger.debug("Flushed session after initial progress update")

            # Process data for each product using enhanced processor
            logger.info("Processing data using enhanced processor")
            processed_data = await self._process_all_products_enhanced(
                sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
            )

            # Categorize all products for category-specific forecasting
            logger.info("Categorizing products for optimized forecasting")
            product_categories = await self._categorize_all_products(
                sales_df, processed_data
            )
            logger.info("Product categorization complete",
                        total_products=len(product_categories),
                        categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
                                              for cat in set(product_categories.values())})

            # Event 2: Data Analysis (20%)
            # Recalculate time remaining based on elapsed time
            start_time = await repos['training_log'].get_start_time(job_id)
            elapsed_seconds = 0
            if start_time:
                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())

            # Estimate remaining time: we've done ~20% of the work (data analysis);
            # the remaining 80% includes training all products
            products_to_train = len(processed_data)
            estimated_remaining_seconds = int(products_to_train * avg_time_per_product)

            # Recalculate estimated completion time
            estimated_completion_time_data_analysis = calculate_estimated_completion_time(
                estimated_remaining_seconds / 60
            )
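
            # calculate_estimated_completion_time() is not shown in this diff; presumably it is
            # little more than "now plus the remaining minutes", along the lines of:
            #
            #     def calculate_estimated_completion_time(minutes: float) -> datetime:
            #         return datetime.now(timezone.utc) + timedelta(minutes=minutes)
            #
            # e.g. 40 products at an assumed ~45 s each gives 1800 s remaining, so the new
            # estimate lands about 30 minutes out.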

            await publish_data_analysis(
                job_id,
                tenant_id,
                f"Data analysis completed for {len(processed_data)} products",
                estimated_time_remaining_seconds=estimated_remaining_seconds,
                estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
            )

            # Train models for each processed product with progress aggregation
            logger.info("Training models with repository integration and progress aggregation")

            # Create progress tracker for parallel product training (20-80%)
            progress_tracker = ParallelProductProgressTracker(
                job_id=job_id,
                tenant_id=tenant_id,
                total_products=len(processed_data)
            )
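
            # ParallelProductProgressTracker's internals are outside this hunk; the aggregation
            # is presumably something like mapping finished products onto the 20-80% band
            # (sketch only, the method and event names below are assumed):
            #
            #     async def on_product_done(self) -> None:
            #         self.completed += 1
            #         pct = 20 + int(60 * self.completed / self.total_products)
            #         await publish_training_progress(self.job_id, self.tenant_id, pct)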

            # Train all models in parallel (without DB writes to avoid session conflicts)
            # ✅ FIX: Pass db_session to prevent nested session issues and deadlocks
            training_results = await self._train_all_models_enhanced(
                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, db_session
            )

            # Write all training results to database sequentially (after parallel training completes)
            logger.info("Writing training results to database sequentially")
            training_results = await self._write_training_results_to_database(
                tenant_id, job_id, training_results, repos
            )
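
            # The split above (train in parallel, persist sequentially) is the usual way to keep
            # a single AsyncSession out of concurrent tasks, since it is not safe for concurrent
            # use. A rough sketch of the assumed shape (helper names are illustrative, not the
            # actual method bodies):
            #
            #     results = await asyncio.gather(
            #         *(self._train_one(tenant_id, pdata) for pdata in processed_data.values()),
            #         return_exceptions=True,
            #     )                                   # model fitting only, no session access
            #     for res in results:
            #         await repos['model'].save(res)  # one writer, one session, no conflicts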

            # Calculate overall training summary with enhanced metrics
            summary = await self._calculate_enhanced_training_summary(
                training_results, repos, tenant_id
            )

            # Calculate successful and failed trainings
            successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
            failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
            total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])

            # Event 4: Training Completed (100%)
            await publish_training_completed(
                job_id,
                tenant_id,
                successful_trainings,
                failed_trainings,
                total_duration
            )

            # Create comprehensive result with repository data
            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "completed",
                "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
                "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
                "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
                "total_products": len(products),
                "training_results": training_results,
                "enhanced_summary": summary,
                "models_trained": summary.get('models_created', {}),
                "data_info": {
                    "date_range": {
                        "start": training_dataset.date_range.start.isoformat(),
                        "end": training_dataset.date_range.end.isoformat(),
                        "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
                    },
                    "data_sources": [source.value for source in training_dataset.date_range.available_sources],
                    "constraints_applied": training_dataset.date_range.constraints
                },
                "repository_metadata": {
                    "total_records_created": summary.get('total_db_records', 0),
                    "performance_metrics_stored": summary.get('performance_metrics_created', 0),
                    "artifacts_created": summary.get('artifacts_created', 0)
                },
                "completed_at": datetime.now().isoformat()
            }

            logger.info("Enhanced ML training pipeline completed successfully",
                        job_id=job_id,
                        models_created=len([r for r in training_results.values() if r.get('status') == 'success']))

            return result

        except Exception as e:
            logger.error("Enhanced ML training pipeline failed",
                         job_id=job_id,
@@ -408,6 +270,12 @@ class EnhancedBakeryMLTrainer:
                tenant_id, job_id, training_results, repos
            )

            # ✅ CRITICAL FIX: Commit the session to persist model records to the database.
            # Without this commit, all model records created above are lost when the session closes.
            await session.commit()
            logger.info("Committed model records to database",
                        models_created=len([r for r in training_results.values() if 'model_record_id' in r]))
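
            # Note the contrast with the earlier flush(): flush only pushes pending INSERTs into
            # the still-open transaction, so if the session is closed or rolled back without a
            # commit(), those model rows never reach the database. That is the failure mode this
            # fix addresses.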

            # Calculate overall training summary with enhanced metrics
            summary = await self._calculate_enhanced_training_summary(
                training_results, repos, tenant_id