Refactor external service and improve WebSocket training

Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions


@@ -10,6 +10,7 @@ from datetime import datetime
import structlog
import uuid
import time
import asyncio
from app.ml.data_processor import EnhancedBakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
@@ -28,7 +29,13 @@ from app.repositories import (
ArtifactRepository
)
from app.services.progress_tracker import ParallelProductProgressTracker
from app.services.training_events import (
publish_training_started,
publish_data_analysis,
publish_training_completed,
publish_training_failed
)
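# --- Illustrative sketch, not part of this commit: the training_events helper
# signatures below are inferred from the call sites in this file; they are
# assumptions about app.services.training_events, not its actual code.
#
# async def publish_training_started(job_id: str, tenant_id: str, products_total: int) -> None: ...
# async def publish_data_analysis(job_id: str, tenant_id: str, message: str) -> None: ...
# async def publish_training_completed(job_id: str, tenant_id: str,
#                                      successful: int, failed: int,
#                                      total_duration_seconds: float) -> None: ...
# async def publish_training_failed(job_id: str, tenant_id: str, error: str) -> None: ...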
logger = structlog.get_logger()
@@ -75,8 +82,6 @@ class EnhancedBakeryMLTrainer:
job_id=job_id,
tenant_id=tenant_id)
try:
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
@@ -113,8 +118,10 @@ class EnhancedBakeryMLTrainer:
else:
logger.info("Multiple products detected for training",
products_count=len(products))
# Event 1: Training Started (0%) - update with the actual product count
# Note: the initial event was already published by the API endpoint; this call updates it with the real count
await publish_training_started(job_id, tenant_id, len(products))
# Create initial training log entry
await repos['training_log'].update_log_progress(
@@ -126,28 +133,45 @@ class EnhancedBakeryMLTrainer:
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id
)
# Event 2: Data Analysis (20%)
await publish_data_analysis(
job_id,
tenant_id,
f"Data analysis completed for {len(processed_data)} products"
)
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")
# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(processed_data)
)
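# --- Illustrative sketch, not part of this commit: a minimal tracker with the
# interface used here (constructor, mark_product_completed, get_progress),
# assuming completions map linearly onto the 20-80% progress band. The real
# implementation lives in app.services.progress_tracker.
#
# class ParallelProductProgressTracker:
#     def __init__(self, job_id: str, tenant_id: str, total_products: int):
#         self.job_id, self.tenant_id = job_id, tenant_id
#         self.total_products = total_products
#         self.products_completed = 0
#         self._lock = asyncio.Lock()  # parallel tasks report completions concurrently
#
#     async def mark_product_completed(self, inventory_product_id: str) -> None:
#         async with self._lock:
#             self.products_completed += 1
#             progress = 20 + int(60 * self.products_completed / self.total_products)
#         # Event 3 (product_completed) would be published here with `progress`
#
#     def get_progress(self) -> dict:
#         return {'total_products': self.total_products,
#                 'products_completed': self.products_completed}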
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker
)
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
)
# Calculate successful and failed trainings
successful_trainings = sum(1 for r in training_results.values() if r.get('status') == 'success')
failed_trainings = sum(1 for r in training_results.values() if r.get('status') == 'error')
total_duration = sum(r.get('training_time_seconds', 0) for r in training_results.values())
# Event 4: Training Completed (100%)
await publish_training_completed(
job_id,
tenant_id,
successful_trainings,
failed_trainings,
total_duration
)
# Create comprehensive result with repository data
@@ -189,6 +213,10 @@ class EnhancedBakeryMLTrainer:
logger.error("Enhanced ML training pipeline failed",
job_id=job_id,
error=str(e))
# Publish training failed event
await publish_training_failed(job_id, tenant_id, str(e))
raise
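# --- Event lifecycle recap (summarized from this method; the percentages are
# taken from the comments above, not from the event service itself):
#   Event 1  training_started    0%       publish_training_started
#   Event 2  data_analysis       20%      publish_data_analysis
#   Event 3  product_completed   20-80%   emitted per product by ParallelProductProgressTracker
#   Event 4  training_completed  100%     publish_training_completed
#   On any error: publish_training_failed, then the exception is re-raised.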
async def _process_all_products_enhanced(self,
@@ -237,111 +265,158 @@ class EnhancedBakeryMLTrainer:
return processed_data
async def _train_single_product(self,
tenant_id: str,
inventory_product_id: str,
product_data: pd.DataFrame,
job_id: str,
repos: Dict,
progress_tracker: ParallelProductProgressTracker) -> tuple[str, Dict[str, Any]]:
"""Train a single product model - used for parallel execution with progress aggregation"""
product_start_time = time.time()
try:
logger.info("Training model", inventory_product_id=inventory_product_id)
# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
result = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
min_required=settings.MIN_TRAINING_DATA_DAYS)
return inventory_product_id, result
# Train the model using Prophet manager
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=product_data,
job_id=job_id
)
# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, inventory_product_id, model_info['training_metrics']
)
result = {
'status': 'success',
'model_info': model_info,
'model_record_id': model_record.id if model_record else None,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat()
}
logger.info("Successfully trained model",
inventory_product_id=inventory_product_id,
model_record_id=model_record.id if model_record else None)
# Report completion to progress tracker (emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
return inventory_product_id, result
except Exception as e:
logger.error("Failed to train model",
inventory_product_id=inventory_product_id,
error=str(e))
result = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0,
'training_time_seconds': time.time() - product_start_time,
'failed_at': datetime.now().isoformat()
}
# Report failure to progress tracker (still emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
return inventory_product_id, result
async def _train_all_models_enhanced(self,
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str,
repos: Dict,
progress_tracker: ParallelProductProgressTracker) -> Dict[str, Any]:
"""Train models with throttled parallel execution and progress tracking"""
total_products = len(processed_data)
logger.info(f"Starting throttled parallel training for {total_products} products")
# Create training tasks for all products
training_tasks = [
self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=product_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker
)
for inventory_product_id, product_data in processed_data.items()
]
# Execute training tasks with throttling to prevent heartbeat blocking
# Limit concurrent operations to prevent CPU/memory exhaustion
from app.core.config import settings
max_concurrent = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)
logger.info(f"Executing training with max {max_concurrent} concurrent operations",
total_products=total_products)
# Process tasks in batches to prevent blocking the event loop
results_list = []
for i in range(0, len(training_tasks), max_concurrent):
batch = training_tasks[i:i + max_concurrent]
batch_results = await asyncio.gather(*batch, return_exceptions=True)
results_list.extend(batch_results)
# Yield control to event loop to allow heartbeat processing
# Increased from 0.01s to 0.1s (100ms) to ensure WebSocket pings, RabbitMQ heartbeats,
# and progress events can be processed during long training operations
await asyncio.sleep(0.1)
# Log progress to verify event loop is responsive
logger.debug(
"Training batch completed, yielding to event loop",
batch_num=(i // max_concurrent) + 1,
total_batches=(len(training_tasks) + max_concurrent - 1) // max_concurrent,
products_completed=len(results_list),
total_products=len(training_tasks)
)
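# --- Illustrative sketch, not part of this commit: the batch loop above fully
# drains each batch before starting the next, so one slow product can idle the
# other slots. An asyncio.Semaphore keeps max_concurrent trainings in flight
# continuously; a minimal alternative under that assumption:
#
# semaphore = asyncio.Semaphore(max_concurrent)
#
# async def throttled(task):
#     async with semaphore:  # at most max_concurrent trainings at once
#         return await task
#
# results_list = await asyncio.gather(
#     *(throttled(task) for task in training_tasks), return_exceptions=True)
#
# (Note: this variant drops the explicit 0.1 s yield between batches; heartbeats
# would then rely on awaits inside the training tasks themselves.)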
# Log final summary
summary = progress_tracker.get_progress()
logger.info("Throttled parallel training completed",
total=summary['total_products'],
completed=summary['products_completed'])
# Convert results to dictionary
training_results = {}
for result in results_list:
if isinstance(result, Exception):
logger.error(f"Training task failed with exception: {result}")
continue
product_id, product_result = result
training_results[product_id] = product_result
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
return training_results
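# --- Illustrative sketch, not part of this commit: shape of the returned
# mapping, assembled from the per-product result dicts built above.
#
# {
#     "<inventory_product_id>": {
#         "status": "success" | "error" | "skipped",
#         "data_points": 123,
#         # success: model_info, model_record_id, training_time_seconds, trained_at
#         # error:   error_message, training_time_seconds, failed_at
#         # skipped: reason, min_required, message
#     },
# }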
async def _create_model_record(self,
@@ -655,7 +730,3 @@ class EnhancedBakeryMLTrainer:
except Exception as e:
logger.error("Enhanced model evaluation failed", error=str(e))
raise