REFACTOR external service and improve websocket training
@@ -10,6 +10,7 @@ from datetime import datetime
import structlog
import uuid
import time
import asyncio

from app.ml.data_processor import EnhancedBakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
@@ -28,7 +29,13 @@ from app.repositories import (
    ArtifactRepository
)

from app.services.messaging import TrainingStatusPublisher
from app.services.progress_tracker import ParallelProductProgressTracker
from app.services.training_events import (
    publish_training_started,
    publish_data_analysis,
    publish_training_completed,
    publish_training_failed
)

logger = structlog.get_logger()

@@ -75,8 +82,6 @@ class EnhancedBakeryMLTrainer:
                    job_id=job_id,
                    tenant_id=tenant_id)

        self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)

        try:
            # Get database session and repositories
            async with self.database_manager.get_session() as db_session:
@@ -113,8 +118,10 @@ class EnhancedBakeryMLTrainer:
                else:
                    logger.info("Multiple products detected for training",
                                products_count=len(products))

                self.status_publisher.products_total = len(products)

                # Event 1: Training Started (0%) - update with actual product count
                # Note: Initial event was already published by API endpoint, this updates with real count
                await publish_training_started(job_id, tenant_id, len(products))

                # Create initial training log entry
                await repos['training_log'].update_log_progress(
@@ -126,28 +133,45 @@ class EnhancedBakeryMLTrainer:
                processed_data = await self._process_all_products_enhanced(
                    sales_df, weather_df, traffic_df, products, tenant_id, job_id
                )

                await self.status_publisher.progress_update(
                    progress=20,
                    step="feature_engineering",
                    step_details="Enhanced processing with repository tracking"

                # Event 2: Data Analysis (20%)
                await publish_data_analysis(
                    job_id,
                    tenant_id,
                    f"Data analysis completed for {len(processed_data)} products"
                )

                # Train models for each processed product
                logger.info("Training models with repository integration")
                # Train models for each processed product with progress aggregation
                logger.info("Training models with repository integration and progress aggregation")

                # Create progress tracker for parallel product training (20-80%)
                progress_tracker = ParallelProductProgressTracker(
                    job_id=job_id,
                    tenant_id=tenant_id,
                    total_products=len(processed_data)
                )

                training_results = await self._train_all_models_enhanced(
                    tenant_id, processed_data, job_id, repos
                    tenant_id, processed_data, job_id, repos, progress_tracker
                )

                # Calculate overall training summary with enhanced metrics
                summary = await self._calculate_enhanced_training_summary(
                    training_results, repos, tenant_id
                )

                await self.status_publisher.progress_update(
                    progress=90,
                    step="model_validation",
                    step_details="Enhanced validation with repository tracking"

                # Calculate successful and failed trainings
                successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
                failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
                total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])

                # Event 4: Training Completed (100%)
                await publish_training_completed(
                    job_id,
                    tenant_id,
                    successful_trainings,
                    failed_trainings,
                    total_duration
                )

                # Create comprehensive result with repository data
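The hunk above settles on a four-event lifecycle: training_started at 0%, data_analysis at 20%, per-product product_completed events across the 20-80% band, and training_completed (or training_failed) at the end. A minimal consumer-side sketch of how a WebSocket client might map those events to a single progress value; the payload field names ("event", "products_total", "products_completed") are assumptions for illustration, not taken from this commit:

# Hypothetical mapping of training lifecycle events to a 0-100 progress value.
# Field names are illustrative; adjust to the payload the publishers actually emit.
def progress_from_event(event: dict) -> int:
    kind = event.get("event")
    if kind == "training_started":
        return 0
    if kind == "data_analysis":
        return 20
    if kind == "product_completed":
        total = max(event.get("products_total", 1), 1)
        done = event.get("products_completed", 0)
        # Per-product completions are spread across the 20-80% band.
        return 20 + int((done / total) * 60)
    if kind == "training_completed":
        return 100
    if kind == "training_failed":
        return -1  # sentinel for a failed run
    return 0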
@@ -189,6 +213,10 @@ class EnhancedBakeryMLTrainer:
            logger.error("Enhanced ML training pipeline failed",
                         job_id=job_id,
                         error=str(e))

            # Publish training failed event
            await publish_training_failed(job_id, tenant_id, str(e))

            raise

    async def _process_all_products_enhanced(self,
@@ -237,111 +265,158 @@ class EnhancedBakeryMLTrainer:

        return processed_data

    async def _train_single_product(self,
                                    tenant_id: str,
                                    inventory_product_id: str,
                                    product_data: pd.DataFrame,
                                    job_id: str,
                                    repos: Dict,
                                    progress_tracker: ParallelProductProgressTracker) -> tuple[str, Dict[str, Any]]:
        """Train a single product model - used for parallel execution with progress aggregation"""
        product_start_time = time.time()

        try:
            logger.info("Training model", inventory_product_id=inventory_product_id)

            # Check if we have enough data
            if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
                result = {
                    'status': 'skipped',
                    'reason': 'insufficient_data',
                    'data_points': len(product_data),
                    'min_required': settings.MIN_TRAINING_DATA_DAYS,
                    'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
                }
                logger.warning("Skipping product due to insufficient data",
                               inventory_product_id=inventory_product_id,
                               data_points=len(product_data),
                               min_required=settings.MIN_TRAINING_DATA_DAYS)
                return inventory_product_id, result

            # Train the model using Prophet manager
            model_info = await self.prophet_manager.train_bakery_model(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id,
                df=product_data,
                job_id=job_id
            )

            # Store model record using repository
            model_record = await self._create_model_record(
                repos, tenant_id, inventory_product_id, model_info, job_id, product_data
            )

            # Create performance metrics record
            if model_info.get('training_metrics'):
                await self._create_performance_metrics(
                    repos, model_record.id if model_record else None,
                    tenant_id, inventory_product_id, model_info['training_metrics']
                )

            result = {
                'status': 'success',
                'model_info': model_info,
                'model_record_id': model_record.id if model_record else None,
                'data_points': len(product_data),
                'training_time_seconds': time.time() - product_start_time,
                'trained_at': datetime.now().isoformat()
            }

            logger.info("Successfully trained model",
                        inventory_product_id=inventory_product_id,
                        model_record_id=model_record.id if model_record else None)

            # Report completion to progress tracker (emits Event 3: product_completed)
            await progress_tracker.mark_product_completed(inventory_product_id)

            return inventory_product_id, result

        except Exception as e:
            logger.error("Failed to train model",
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            result = {
                'status': 'error',
                'error_message': str(e),
                'data_points': len(product_data) if product_data is not None else 0,
                'training_time_seconds': time.time() - product_start_time,
                'failed_at': datetime.now().isoformat()
            }

            # Report failure to progress tracker (still emits Event 3: product_completed)
            await progress_tracker.mark_product_completed(inventory_product_id)

            return inventory_product_id, result

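The tracker is only exercised through two calls in this diff, mark_product_completed() and get_progress(). A minimal sketch of what ParallelProductProgressTracker could look like under those assumptions; how it actually publishes the product_completed event and its exact payload are not shown in this commit:

import asyncio
from typing import Dict


class ParallelProductProgressTracker:
    """Sketch only: aggregates per-product completions for a parallel training run."""

    def __init__(self, job_id: str, tenant_id: str, total_products: int):
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.total_products = total_products
        self.products_completed = 0
        self._lock = asyncio.Lock()  # guard the counter across concurrent product tasks

    async def mark_product_completed(self, inventory_product_id: str) -> None:
        # Count the product; the real implementation would also emit Event 3: product_completed.
        async with self._lock:
            self.products_completed += 1

    def get_progress(self) -> Dict[str, int]:
        return {
            "total_products": self.total_products,
            "products_completed": self.products_completed,
        }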
    async def _train_all_models_enhanced(self,
                                         tenant_id: str,
                                         processed_data: Dict[str, pd.DataFrame],
                                         job_id: str,
                                         repos: Dict) -> Dict[str, Any]:
        """Train models with enhanced repository integration"""
        training_results = {}
        i = 0
                                         repos: Dict,
                                         progress_tracker: ParallelProductProgressTracker) -> Dict[str, Any]:
        """Train models with throttled parallel execution and progress tracking"""
        total_products = len(processed_data)
        base_progress = 45
        max_progress = 85
        logger.info(f"Starting throttled parallel training for {total_products} products")

        # Create training tasks for all products
        training_tasks = [
            self._train_single_product(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id,
                product_data=product_data,
                job_id=job_id,
                repos=repos,
                progress_tracker=progress_tracker
            )
            for inventory_product_id, product_data in processed_data.items()
        ]

        # Execute training tasks with throttling to prevent heartbeat blocking
        # Limit concurrent operations to prevent CPU/memory exhaustion
        from app.core.config import settings
        max_concurrent = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)

        for inventory_product_id, product_data in processed_data.items():
            product_start_time = time.time()
            try:
                logger.info("Training enhanced model",
                            inventory_product_id=inventory_product_id)

                # Check if we have enough data
                if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
                    training_results[inventory_product_id] = {
                        'status': 'skipped',
                        'reason': 'insufficient_data',
                        'data_points': len(product_data),
                        'min_required': settings.MIN_TRAINING_DATA_DAYS,
                        'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
                    }
                    logger.warning("Skipping product due to insufficient data",
                                   inventory_product_id=inventory_product_id,
                                   data_points=len(product_data),
                                   min_required=settings.MIN_TRAINING_DATA_DAYS)
                    continue

                # Train the model using Prophet manager
                model_info = await self.prophet_manager.train_bakery_model(
                    tenant_id=tenant_id,
                    inventory_product_id=inventory_product_id,
                    df=product_data,
                    job_id=job_id
                )

                # Store model record using repository
                model_record = await self._create_model_record(
                    repos, tenant_id, inventory_product_id, model_info, job_id, product_data
                )

                # Create performance metrics record
                if model_info.get('training_metrics'):
                    await self._create_performance_metrics(
                        repos, model_record.id if model_record else None,
                        tenant_id, inventory_product_id, model_info['training_metrics']
                    )

                training_results[inventory_product_id] = {
                    'status': 'success',
                    'model_info': model_info,
                    'model_record_id': model_record.id if model_record else None,
                    'data_points': len(product_data),
                    'training_time_seconds': time.time() - product_start_time,
                    'trained_at': datetime.now().isoformat()
                }

                logger.info("Successfully trained enhanced model",
                            inventory_product_id=inventory_product_id,
                            model_record_id=model_record.id if model_record else None)

                completed_products = i + 1
                i += 1
                progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))

                if self.status_publisher:
                    self.status_publisher.products_completed = completed_products

                    await self.status_publisher.progress_update(
                        progress=progress,
                        step="model_training",
                        current_product=inventory_product_id,
                        step_details=f"Enhanced training completed for {inventory_product_id}"
                    )

            except Exception as e:
                logger.error("Failed to train enhanced model",
                             inventory_product_id=inventory_product_id,
                             error=str(e))
                training_results[inventory_product_id] = {
                    'status': 'error',
                    'error_message': str(e),
                    'data_points': len(product_data) if product_data is not None else 0,
                    'training_time_seconds': time.time() - product_start_time,
                    'failed_at': datetime.now().isoformat()
                }

                completed_products = i + 1
                i += 1
                progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))

                if self.status_publisher:
                    self.status_publisher.products_completed = completed_products
                    await self.status_publisher.progress_update(
                        progress=progress,
                        step="model_training",
                        current_product=inventory_product_id,
                        step_details=f"Enhanced training failed for {inventory_product_id}: {str(e)}"
                    )

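The sequential loop above (the pre-refactor path) mapped completions linearly onto a 45-85% progress band. A short worked example of that formula, using hypothetical counts:

# Worked example of the sequential loop's progress formula (45-85% band):
base_progress, max_progress, total_products, completed_products = 45, 85, 10, 3
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
assert progress == 57  # 45 + int(0.3 * 40)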
logger.info(f"Executing training with max {max_concurrent} concurrent operations",
|
||||
total_products=total_products)
|
||||
|
||||
# Process tasks in batches to prevent blocking the event loop
|
||||
results_list = []
|
||||
for i in range(0, len(training_tasks), max_concurrent):
|
||||
batch = training_tasks[i:i + max_concurrent]
|
||||
batch_results = await asyncio.gather(*batch, return_exceptions=True)
|
||||
results_list.extend(batch_results)
|
||||
|
||||
# Yield control to event loop to allow heartbeat processing
|
||||
# Increased from 0.01s to 0.1s (100ms) to ensure WebSocket pings, RabbitMQ heartbeats,
|
||||
# and progress events can be processed during long training operations
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Log progress to verify event loop is responsive
|
||||
logger.debug(
|
||||
"Training batch completed, yielding to event loop",
|
||||
batch_num=(i // max_concurrent) + 1,
|
||||
total_batches=(len(training_tasks) + max_concurrent - 1) // max_concurrent,
|
||||
products_completed=len(results_list),
|
||||
total_products=len(training_tasks)
|
||||
)
|
||||
|
||||
# Log final summary
|
||||
summary = progress_tracker.get_progress()
|
||||
logger.info("Throttled parallel training completed",
|
||||
total=summary['total_products'],
|
||||
completed=summary['products_completed'])
|
||||
|
||||
# Convert results to dictionary
|
||||
training_results = {}
|
||||
for result in results_list:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Training task failed with exception: {result}")
|
||||
continue
|
||||
|
||||
product_id, product_result = result
|
||||
training_results[product_id] = product_result
|
||||
|
||||
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
|
||||
return training_results
|
||||
|
||||
async def _create_model_record(self,
|
||||
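The commit caps concurrency by slicing the task list into fixed batches and sleeping 100 ms between them. An alternative sketch, not what this commit does, that caps concurrency with asyncio.Semaphore instead; a new task starts as soon as a slot frees up, while awaits inside each task still yield to the event loop:

import asyncio
from typing import Any, Awaitable, List

async def gather_throttled(coros: List[Awaitable[Any]], max_concurrent: int = 3) -> List[Any]:
    """Sketch only: cap concurrent coroutines with a semaphore rather than fixed-size batches."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def guarded(coro: Awaitable[Any]) -> Any:
        async with semaphore:
            return await coro

    # return_exceptions=True mirrors the batched approach: one failed product
    # does not cancel the rest of the training run.
    return await asyncio.gather(*(guarded(c) for c in coros), return_exceptions=True)

The batch-plus-sleep approach chosen in the diff is simpler to reason about and adds an explicit 100 ms breathing point per batch for WebSocket pings and RabbitMQ heartbeats, at the cost of idle slots when one product in a batch trains much longer than the others.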
@@ -655,7 +730,3 @@ class EnhancedBakeryMLTrainer:
        except Exception as e:
            logger.error("Enhanced model evaluation failed", error=str(e))
            raise


# Legacy compatibility alias
BakeryMLTrainer = EnhancedBakeryMLTrainer