REFACTOR external service and improve WebSocket training
@@ -3,16 +3,12 @@ ML Pipeline Components
Machine learning training and prediction components
"""

from .trainer import BakeryMLTrainer
from .trainer import EnhancedBakeryMLTrainer
from .data_processor import BakeryDataProcessor
from .data_processor import EnhancedBakeryDataProcessor
from .prophet_manager import BakeryProphetManager

__all__ = [
    "BakeryMLTrainer",
    "EnhancedBakeryMLTrainer",
    "BakeryDataProcessor",
    "EnhancedBakeryDataProcessor",
    "BakeryProphetManager"
]
@@ -865,8 +865,4 @@ class EnhancedBakeryDataProcessor:

except Exception as e:
logger.error("Error generating data quality report", error=str(e))
return {"error": str(e)}


# Legacy compatibility alias
BakeryDataProcessor = EnhancedBakeryDataProcessor
@@ -32,6 +32,10 @@ import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)

from app.core.config import settings
from app.core import constants as const
from app.utils.timezone_utils import prepare_prophet_datetime
from app.utils.file_utils import ChecksummedFile, calculate_file_checksum
from app.utils.distributed_lock import get_training_lock, LockAcquisitionError

logger = logging.getLogger(__name__)

@@ -50,72 +54,79 @@ class BakeryProphetManager:
# Ensure model storage directory exists
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)

async def train_bakery_model(self,
tenant_id: str,
inventory_product_id: str,
async def train_bakery_model(self,
tenant_id: str,
inventory_product_id: str,
df: pd.DataFrame,
job_id: str) -> Dict[str, Any]:
"""
Train a Prophet model with automatic hyperparameter optimization.
Same interface as before - optimization happens automatically.
Train a Prophet model with automatic hyperparameter optimization and distributed locking.
"""
# Acquire distributed lock to prevent concurrent training of same product
lock = get_training_lock(tenant_id, inventory_product_id, use_advisory=True)

try:
logger.info(f"Training optimized bakery model for {inventory_product_id}")

# Validate input data
await self._validate_training_data(df, inventory_product_id)

# Prepare data for Prophet
prophet_data = await self._prepare_prophet_data(df)

# Get regressor columns
regressor_columns = self._extract_regressor_columns(prophet_data)

# Automatically optimize hyperparameters (this is the new part)
logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)

# Create optimized Prophet model
model = self._create_optimized_prophet_model(best_params, regressor_columns)

# Add regressors to model
for regressor in regressor_columns:
if regressor in prophet_data.columns:
model.add_regressor(regressor)

# Fit the model
model.fit(prophet_data)

# Calculate enhanced training metrics first
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)

# Store model and metrics - Generate proper UUID for model_id
model_id = str(uuid.uuid4())
model_path = await self._store_model(
tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
)

# Return same format as before, but with optimization info
model_info = {
"model_id": model_id,
"model_path": model_path,
"type": "prophet_optimized",  # Changed from "prophet"
"training_samples": len(prophet_data),
"features": regressor_columns,
"hyperparameters": best_params,  # Now contains optimized params
"training_metrics": training_metrics,
"trained_at": datetime.now().isoformat(),
"data_period": {
"start_date": prophet_data['ds'].min().isoformat(),
"end_date": prophet_data['ds'].max().isoformat(),
"total_days": len(prophet_data)
}
}

logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
return model_info

async with self.database_manager.get_session() as session:
async with lock.acquire(session):
logger.info(f"Training optimized bakery model for {inventory_product_id} (lock acquired)")

# Validate input data
await self._validate_training_data(df, inventory_product_id)

# Prepare data for Prophet
prophet_data = await self._prepare_prophet_data(df)

# Get regressor columns
regressor_columns = self._extract_regressor_columns(prophet_data)

# Automatically optimize hyperparameters
logger.info(f"Optimizing hyperparameters for {inventory_product_id}...")
best_params = await self._optimize_hyperparameters(prophet_data, inventory_product_id, regressor_columns)

# Create optimized Prophet model
model = self._create_optimized_prophet_model(best_params, regressor_columns)

# Add regressors to model
for regressor in regressor_columns:
if regressor in prophet_data.columns:
model.add_regressor(regressor)

# Fit the model
model.fit(prophet_data)

# Calculate enhanced training metrics first
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)

# Store model and metrics - Generate proper UUID for model_id
model_id = str(uuid.uuid4())
model_path = await self._store_model(
tenant_id, inventory_product_id, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
)

# Return same format as before, but with optimization info
model_info = {
"model_id": model_id,
"model_path": model_path,
"type": "prophet_optimized",
"training_samples": len(prophet_data),
"features": regressor_columns,
"hyperparameters": best_params,
"training_metrics": training_metrics,
"trained_at": datetime.now().isoformat(),
"data_period": {
"start_date": prophet_data['ds'].min().isoformat(),
"end_date": prophet_data['ds'].max().isoformat(),
"total_days": len(prophet_data)
}
}

logger.info(f"Optimized model trained successfully for {inventory_product_id}. "
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
return model_info

except LockAcquisitionError as e:
logger.warning(f"Could not acquire lock for {inventory_product_id}: {e}")
raise RuntimeError(f"Training already in progress for product {inventory_product_id}")
except Exception as e:
logger.error(f"Failed to train optimized bakery model for {inventory_product_id}: {str(e)}")
raise
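For context, a minimal sketch of the per-product locking pattern this hunk relies on, assuming a transaction-scoped PostgreSQL advisory lock keyed on tenant and product. The real `get_training_lock` / `LockAcquisitionError` live in `app.utils.distributed_lock` and may differ in detail.

```python
# Hypothetical sketch of a per-product advisory training lock; the project's
# app.utils.distributed_lock implementation may differ.
import hashlib
from contextlib import asynccontextmanager

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


class LockAcquisitionError(RuntimeError):
    """Raised when another worker already holds the training lock."""


def _lock_key(tenant_id: str, inventory_product_id: str) -> int:
    # Derive a stable 64-bit signed key for pg_try_advisory_xact_lock.
    digest = hashlib.sha256(f"{tenant_id}:{inventory_product_id}".encode()).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


class TrainingLock:
    def __init__(self, tenant_id: str, inventory_product_id: str):
        self.key = _lock_key(tenant_id, inventory_product_id)

    @asynccontextmanager
    async def acquire(self, session: AsyncSession):
        # Transaction-scoped advisory lock: released automatically on commit/rollback.
        acquired = await session.scalar(
            text("SELECT pg_try_advisory_xact_lock(:key)"), {"key": self.key}
        )
        if not acquired:
            raise LockAcquisitionError("training already in progress for this product")
        yield
```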
@@ -134,11 +145,11 @@ class BakeryProphetManager:

# Set optimization parameters based on category
n_trials = {
'high_volume': 30,    # Reduced from 75 for speed
'medium_volume': 25,  # Reduced from 50
'low_volume': 20,     # Reduced from 30
'intermittent': 15    # Reduced from 25
}.get(product_category, 25)
'high_volume': const.OPTUNA_TRIALS_HIGH_VOLUME,
'medium_volume': const.OPTUNA_TRIALS_MEDIUM_VOLUME,
'low_volume': const.OPTUNA_TRIALS_LOW_VOLUME,
'intermittent': const.OPTUNA_TRIALS_INTERMITTENT
}.get(product_category, const.OPTUNA_TRIALS_MEDIUM_VOLUME)

logger.info(f"Product {inventory_product_id} classified as {product_category}, using {n_trials} trials")

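For reference, a hedged sketch of how the `const.*` values referenced by this commit might be centralized in `app/core/constants.py`. The names come from this diff; the values shown simply mirror the hard-coded literals being replaced, and the actual module may define different ones.

```python
# Hypothetical app/core/constants.py excerpt; values mirror the removed literals.
OPTUNA_TRIALS_HIGH_VOLUME = 30
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
OPTUNA_TRIALS_LOW_VOLUME = 20
OPTUNA_TRIALS_INTERMITTENT = 15

OPTUNA_TIMEOUT_SECONDS = 600  # mirrors the previous timeout=600

# Sparsity thresholds used when classifying demand patterns
MAX_ZERO_RATIO_INTERMITTENT = 0.8
MIN_NON_ZERO_DAYS = 30
MODERATE_SPARSITY_THRESHOLD = 0.6
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365

# Uncertainty-sample ranges per product category
UNCERTAINTY_SAMPLES_HIGH_MIN, UNCERTAINTY_SAMPLES_HIGH_MAX = 300, 800
UNCERTAINTY_SAMPLES_MEDIUM_MIN, UNCERTAINTY_SAMPLES_MEDIUM_MAX = 200, 500
UNCERTAINTY_SAMPLES_LOW_MIN, UNCERTAINTY_SAMPLES_LOW_MAX = 150, 300
UNCERTAINTY_SAMPLES_SPARSE_MIN, UNCERTAINTY_SAMPLES_SPARSE_MAX = 100, 200
```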
@@ -152,7 +163,7 @@ class BakeryProphetManager:
f"zero_ratio={zero_ratio:.2f}, mean_sales={mean_sales:.2f}, non_zero_days={non_zero_days}")

# Adjust strategy based on data characteristics
if zero_ratio > 0.8 or non_zero_days < 30:
if zero_ratio > const.MAX_ZERO_RATIO_INTERMITTENT or non_zero_days < const.MIN_NON_ZERO_DAYS:
logger.warning(f"Very sparse data for {inventory_product_id}, using minimal optimization")
return {
'changepoint_prior_scale': 0.001,
@@ -163,9 +174,9 @@ class BakeryProphetManager:
'daily_seasonality': False,
'weekly_seasonality': True,
'yearly_seasonality': False,
'uncertainty_samples': 100  # ✅ FIX: Minimal uncertainty sampling for very sparse data
'uncertainty_samples': const.UNCERTAINTY_SAMPLES_SPARSE_MIN
}
elif zero_ratio > 0.6:
elif zero_ratio > const.MODERATE_SPARSITY_THRESHOLD:
logger.info(f"Moderate sparsity for {inventory_product_id}, using conservative optimization")
return {
'changepoint_prior_scale': 0.01,
@@ -175,8 +186,8 @@ class BakeryProphetManager:
'seasonality_mode': 'additive',
'daily_seasonality': False,
'weekly_seasonality': True,
'yearly_seasonality': len(df) > 365,  # Only if we have enough data
'uncertainty_samples': 200  # ✅ FIX: Conservative uncertainty sampling for moderately sparse data
'yearly_seasonality': len(df) > const.DATA_QUALITY_DAY_THRESHOLD_HIGH,
'uncertainty_samples': const.UNCERTAINTY_SAMPLES_SPARSE_MAX
}

# Use unique seed for each product to avoid identical results
@@ -198,15 +209,15 @@ class BakeryProphetManager:
changepoint_scale_range = (0.001, 0.5)
seasonality_scale_range = (0.01, 10.0)

# ✅ FIX: Determine appropriate uncertainty samples range based on product category
# Determine appropriate uncertainty samples range based on product category
if product_category == 'high_volume':
uncertainty_range = (300, 800)  # More samples for stable high-volume products
uncertainty_range = (const.UNCERTAINTY_SAMPLES_HIGH_MIN, const.UNCERTAINTY_SAMPLES_HIGH_MAX)
elif product_category == 'medium_volume':
uncertainty_range = (200, 500)  # Moderate samples for medium volume
uncertainty_range = (const.UNCERTAINTY_SAMPLES_MEDIUM_MIN, const.UNCERTAINTY_SAMPLES_MEDIUM_MAX)
elif product_category == 'low_volume':
uncertainty_range = (150, 300)  # Fewer samples for low volume
uncertainty_range = (const.UNCERTAINTY_SAMPLES_LOW_MIN, const.UNCERTAINTY_SAMPLES_LOW_MAX)
else:  # intermittent
uncertainty_range = (100, 200)  # Minimal samples for intermittent demand
uncertainty_range = (const.UNCERTAINTY_SAMPLES_SPARSE_MIN, const.UNCERTAINTY_SAMPLES_SPARSE_MAX)

params = {
'changepoint_prior_scale': trial.suggest_float(
@@ -295,10 +306,10 @@ class BakeryProphetManager:

# Run optimization with product-specific seed
study = optuna.create_study(
direction='minimize',
sampler=optuna.samplers.TPESampler(seed=product_seed)  # Unique seed per product
direction='minimize',
sampler=optuna.samplers.TPESampler(seed=product_seed)
)
study.optimize(objective, n_trials=n_trials, timeout=600, show_progress_bar=False)
study.optimize(objective, n_trials=n_trials, timeout=const.OPTUNA_TIMEOUT_SECONDS, show_progress_bar=False)

# Return best parameters
best_params = study.best_params
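To make the optimization flow above concrete, here is a minimal, self-contained Optuna sketch of the same pattern (per-product seeded TPE sampler, bounded trials and timeout). The objective here is only a placeholder; the real one cross-validates a Prophet model and returns its error, and the timeout value is an assumption mirroring the previous literal.

```python
# Standalone sketch of the seeded, time-boxed Optuna search used above.
import zlib

import optuna

optuna.logging.set_verbosity(optuna.logging.WARNING)

OPTUNA_TIMEOUT_SECONDS = 600  # assumed value mirroring the previous timeout=600


def objective(trial: optuna.Trial) -> float:
    params = {
        "changepoint_prior_scale": trial.suggest_float(
            "changepoint_prior_scale", 0.001, 0.5, log=True
        ),
        "seasonality_prior_scale": trial.suggest_float(
            "seasonality_prior_scale", 0.01, 10.0, log=True
        ),
        "uncertainty_samples": trial.suggest_int("uncertainty_samples", 100, 800),
    }
    # Placeholder score; the real objective returns a cross-validated MAPE.
    return params["changepoint_prior_scale"] + 1.0 / params["uncertainty_samples"]


# Deterministic per-product seed so different products do not get identical studies.
product_seed = zlib.crc32(b"tenant-id:inventory-product-id")
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=product_seed),
)
study.optimize(objective, n_trials=25, timeout=OPTUNA_TIMEOUT_SECONDS, show_progress_bar=False)
best_params = study.best_params
```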
@@ -515,8 +526,12 @@ class BakeryProphetManager:
# Store model file
model_path = model_dir / f"{model_id}.pkl"
joblib.dump(model, model_path)

# Enhanced metadata

# Calculate checksum for model file integrity
checksummed_file = ChecksummedFile(str(model_path))
model_checksum = checksummed_file.calculate_and_save_checksum()

# Enhanced metadata with checksum
metadata = {
"model_id": model_id,
"tenant_id": tenant_id,
@@ -531,9 +546,11 @@ class BakeryProphetManager:
"optimized_parameters": optimized_params or {},
"created_at": datetime.now().isoformat(),
"model_type": "prophet_optimized",
"file_path": str(model_path)
"file_path": str(model_path),
"checksum": model_checksum,
"checksum_algorithm": "sha256"
}

metadata_path = model_path.with_suffix('.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2, default=str)
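A hedged sketch of the checksum behaviour assumed from `ChecksummedFile`: compute a SHA-256 over the model file, persist it as a sidecar file, and re-verify it before loading. The actual helpers in `app.utils.file_utils` may differ.

```python
# Illustrative stand-in for ChecksummedFile / calculate_file_checksum;
# the project's app.utils.file_utils implementation may differ.
import hashlib
from pathlib import Path


def calculate_file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    sha256 = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest()


class ChecksummedFile:
    def __init__(self, path: str):
        self.path = Path(path)
        self.sidecar = self.path.with_suffix(self.path.suffix + ".sha256")

    def calculate_and_save_checksum(self) -> str:
        checksum = calculate_file_checksum(str(self.path))
        self.sidecar.write_text(checksum)
        return checksum

    def load_and_verify_checksum(self) -> bool:
        # A missing sidecar or a mismatch both count as failed verification.
        if not self.sidecar.exists():
            return False
        return self.sidecar.read_text().strip() == calculate_file_checksum(str(self.path))
```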
@@ -609,23 +626,29 @@ class BakeryProphetManager:
logger.error(f"Failed to deactivate previous models: {str(e)}")
raise

# Keep all existing methods unchanged
async def generate_forecast(self,
async def generate_forecast(self,
model_path: str,
future_dates: pd.DataFrame,
regressor_columns: List[str]) -> pd.DataFrame:
"""Generate forecast using stored model (unchanged)"""
"""Generate forecast using stored model with checksum verification"""
try:
# Verify model file integrity before loading
checksummed_file = ChecksummedFile(model_path)
if not checksummed_file.load_and_verify_checksum():
logger.warning(f"Checksum verification failed for model: {model_path}")
# Still load the model but log a warning
# In production, you might want to raise an exception instead

model = joblib.load(model_path)

for regressor in regressor_columns:
if regressor not in future_dates.columns:
logger.warning(f"Missing regressor {regressor}, filling with 0")
future_dates[regressor] = 0

forecast = model.predict(future_dates)
return forecast

except Exception as e:
logger.error(f"Failed to generate forecast: {str(e)}")
raise
@@ -655,34 +678,28 @@ class BakeryProphetManager:
async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data for Prophet training with timezone handling"""
prophet_data = df.copy()

if 'ds' not in prophet_data.columns:
raise ValueError("Missing 'ds' column in training data")
if 'y' not in prophet_data.columns:
raise ValueError("Missing 'y' column in training data")

# Convert to datetime and remove timezone information
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])

# Remove timezone if present (Prophet doesn't support timezones)
if prophet_data['ds'].dt.tz is not None:
logger.info("Removing timezone information from 'ds' column for Prophet compatibility")
prophet_data['ds'] = prophet_data['ds'].dt.tz_localize(None)

# Use timezone utility to prepare Prophet-compatible datetime
prophet_data = prepare_prophet_datetime(prophet_data, 'ds')

# Sort by date and clean data
prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
prophet_data['y'] = pd.to_numeric(prophet_data['y'], errors='coerce')
prophet_data = prophet_data.dropna(subset=['y'])

# Additional data cleaning for Prophet

# Remove any duplicate dates (keep last occurrence)
prophet_data = prophet_data.drop_duplicates(subset=['ds'], keep='last')

# Ensure y values are non-negative (Prophet works better with non-negative values)

# Ensure y values are non-negative
prophet_data['y'] = prophet_data['y'].clip(lower=0)

logger.info(f"Prepared Prophet data: {len(prophet_data)} rows, date range: {prophet_data['ds'].min()} to {prophet_data['ds'].max()}")

return prophet_data

def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:

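For illustration, a minimal sketch of the timezone-stripping and cleanup that `prepare_prophet_datetime` is assumed to perform on the `ds` column (Prophet requires tz-naive datetimes), followed by the same pandas cleanup steps used in the method above. The real helper in `app.utils.timezone_utils` may do more.

```python
# Hedged sketch of the assumed prepare_prophet_datetime behaviour plus cleanup.
import pandas as pd


def prepare_prophet_datetime(df: pd.DataFrame, column: str = "ds") -> pd.DataFrame:
    out = df.copy()
    out[column] = pd.to_datetime(out[column])
    # Prophet rejects timezone-aware timestamps, so drop the tz info if present.
    if out[column].dt.tz is not None:
        out[column] = out[column].dt.tz_localize(None)
    return out


# Example cleanup mirroring _prepare_prophet_data
raw = pd.DataFrame({
    "ds": ["2024-01-02T08:00:00+01:00", "2024-01-01T08:00:00+01:00"],
    "y": ["3", "-1"],
})
data = prepare_prophet_datetime(raw, "ds")
data = data.sort_values("ds").reset_index(drop=True)
data["y"] = pd.to_numeric(data["y"], errors="coerce")
data = data.dropna(subset=["y"]).drop_duplicates(subset=["ds"], keep="last")
data["y"] = data["y"].clip(lower=0)  # Prophet behaves better with non-negative targets
```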
@@ -10,6 +10,7 @@ from datetime import datetime
import structlog
import uuid
import time
import asyncio

from app.ml.data_processor import EnhancedBakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
@@ -28,7 +29,13 @@ from app.repositories import (
ArtifactRepository
)

from app.services.messaging import TrainingStatusPublisher
from app.services.progress_tracker import ParallelProductProgressTracker
from app.services.training_events import (
publish_training_started,
publish_data_analysis,
publish_training_completed,
publish_training_failed
)

logger = structlog.get_logger()

@@ -75,8 +82,6 @@ class EnhancedBakeryMLTrainer:
job_id=job_id,
tenant_id=tenant_id)

self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)

try:
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
@@ -113,8 +118,10 @@ class EnhancedBakeryMLTrainer:
else:
logger.info("Multiple products detected for training",
products_count=len(products))

self.status_publisher.products_total = len(products)

# Event 1: Training Started (0%) - update with actual product count
# Note: Initial event was already published by API endpoint, this updates with real count
await publish_training_started(job_id, tenant_id, len(products))

# Create initial training log entry
await repos['training_log'].update_log_progress(
@@ -126,28 +133,45 @@ class EnhancedBakeryMLTrainer:
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id
)

await self.status_publisher.progress_update(
progress=20,
step="feature_engineering",
step_details="Enhanced processing with repository tracking"

# Event 2: Data Analysis (20%)
await publish_data_analysis(
job_id,
tenant_id,
f"Data analysis completed for {len(processed_data)} products"
)

# Train models for each processed product
logger.info("Training models with repository integration")
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")

# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(processed_data)
)

training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos
tenant_id, processed_data, job_id, repos, progress_tracker
)

# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
)

await self.status_publisher.progress_update(
progress=90,
step="model_validation",
step_details="Enhanced validation with repository tracking"

# Calculate successful and failed trainings
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])

# Event 4: Training Completed (100%)
await publish_training_completed(
job_id,
tenant_id,
successful_trainings,
failed_trainings,
total_duration
)

# Create comprehensive result with repository data
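As a side note, the parallel-product phase is mapped onto a fixed slice of the overall progress bar (20-80% per the comment above; the trainer below uses 45-85 for its own per-product updates). A hedged sketch of that mapping, with the class and method names taken from this diff and the internals assumed:

```python
# Hedged sketch of how ParallelProductProgressTracker is assumed to map
# per-product completions onto a fixed slice of the overall progress bar.
class ParallelProductProgressTracker:
    def __init__(self, job_id: str, tenant_id: str, total_products: int,
                 base_progress: int = 20, max_progress: int = 80):
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.total_products = total_products
        self.products_completed = 0
        self.base_progress = base_progress
        self.max_progress = max_progress

    async def mark_product_completed(self, inventory_product_id: str) -> int:
        # Each completed product advances the bar a proportional step inside the slice;
        # the real implementation also publishes a product_completed event.
        self.products_completed += 1
        span = self.max_progress - self.base_progress
        return self.base_progress + int(self.products_completed / self.total_products * span)

    def get_progress(self) -> dict:
        return {
            "total_products": self.total_products,
            "products_completed": self.products_completed,
        }
```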
@@ -189,6 +213,10 @@ class EnhancedBakeryMLTrainer:
logger.error("Enhanced ML training pipeline failed",
job_id=job_id,
error=str(e))

# Publish training failed event
await publish_training_failed(job_id, tenant_id, str(e))

raise

async def _process_all_products_enhanced(self,
@@ -237,111 +265,158 @@ class EnhancedBakeryMLTrainer:

return processed_data

async def _train_single_product(self,
tenant_id: str,
inventory_product_id: str,
product_data: pd.DataFrame,
job_id: str,
repos: Dict,
progress_tracker: ParallelProductProgressTracker) -> tuple[str, Dict[str, Any]]:
"""Train a single product model - used for parallel execution with progress aggregation"""
product_start_time = time.time()

try:
logger.info("Training model", inventory_product_id=inventory_product_id)

# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
result = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
min_required=settings.MIN_TRAINING_DATA_DAYS)
return inventory_product_id, result

# Train the model using Prophet manager
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=product_data,
job_id=job_id
)

# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)

# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, inventory_product_id, model_info['training_metrics']
)

result = {
'status': 'success',
'model_info': model_info,
'model_record_id': model_record.id if model_record else None,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat()
}

logger.info("Successfully trained model",
inventory_product_id=inventory_product_id,
model_record_id=model_record.id if model_record else None)

# Report completion to progress tracker (emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)

return inventory_product_id, result

except Exception as e:
logger.error("Failed to train model",
inventory_product_id=inventory_product_id,
error=str(e))
result = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0,
'training_time_seconds': time.time() - product_start_time,
'failed_at': datetime.now().isoformat()
}

# Report failure to progress tracker (still emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)

return inventory_product_id, result

async def _train_all_models_enhanced(self,
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str,
repos: Dict) -> Dict[str, Any]:
"""Train models with enhanced repository integration"""
training_results = {}
i = 0
repos: Dict,
progress_tracker: ParallelProductProgressTracker) -> Dict[str, Any]:
"""Train models with throttled parallel execution and progress tracking"""
total_products = len(processed_data)
base_progress = 45
max_progress = 85
logger.info(f"Starting throttled parallel training for {total_products} products")

# Create training tasks for all products
training_tasks = [
self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=product_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker
)
for inventory_product_id, product_data in processed_data.items()
]

# Execute training tasks with throttling to prevent heartbeat blocking
# Limit concurrent operations to prevent CPU/memory exhaustion
from app.core.config import settings
max_concurrent = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)

for inventory_product_id, product_data in processed_data.items():
product_start_time = time.time()
try:
logger.info("Training enhanced model",
inventory_product_id=inventory_product_id)

# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
training_results[inventory_product_id] = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
min_required=settings.MIN_TRAINING_DATA_DAYS)
continue

# Train the model using Prophet manager
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=product_data,
job_id=job_id
)

# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)

# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, inventory_product_id, model_info['training_metrics']
)

training_results[inventory_product_id] = {
'status': 'success',
'model_info': model_info,
'model_record_id': model_record.id if model_record else None,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat()
}

logger.info("Successfully trained enhanced model",
inventory_product_id=inventory_product_id,
model_record_id=model_record.id if model_record else None)

completed_products = i + 1
i += 1
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))

if self.status_publisher:
self.status_publisher.products_completed = completed_products

await self.status_publisher.progress_update(
progress=progress,
step="model_training",
current_product=inventory_product_id,
step_details=f"Enhanced training completed for {inventory_product_id}"
)

except Exception as e:
logger.error("Failed to train enhanced model",
inventory_product_id=inventory_product_id,
error=str(e))
training_results[inventory_product_id] = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0,
'training_time_seconds': time.time() - product_start_time,
'failed_at': datetime.now().isoformat()
}

completed_products = i + 1
i += 1
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))

if self.status_publisher:
self.status_publisher.products_completed = completed_products
await self.status_publisher.progress_update(
progress=progress,
step="model_training",
current_product=inventory_product_id,
step_details=f"Enhanced training failed for {inventory_product_id}: {str(e)}"
)

logger.info(f"Executing training with max {max_concurrent} concurrent operations",
total_products=total_products)

# Process tasks in batches to prevent blocking the event loop
results_list = []
for i in range(0, len(training_tasks), max_concurrent):
batch = training_tasks[i:i + max_concurrent]
batch_results = await asyncio.gather(*batch, return_exceptions=True)
results_list.extend(batch_results)

# Yield control to event loop to allow heartbeat processing
# Increased from 0.01s to 0.1s (100ms) to ensure WebSocket pings, RabbitMQ heartbeats,
# and progress events can be processed during long training operations
await asyncio.sleep(0.1)

# Log progress to verify event loop is responsive
logger.debug(
"Training batch completed, yielding to event loop",
batch_num=(i // max_concurrent) + 1,
total_batches=(len(training_tasks) + max_concurrent - 1) // max_concurrent,
products_completed=len(results_list),
total_products=len(training_tasks)
)

# Log final summary
summary = progress_tracker.get_progress()
logger.info("Throttled parallel training completed",
total=summary['total_products'],
completed=summary['products_completed'])

# Convert results to dictionary
training_results = {}
for result in results_list:
if isinstance(result, Exception):
logger.error(f"Training task failed with exception: {result}")
continue

product_id, product_result = result
training_results[product_id] = product_result

logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
return training_results

async def _create_model_record(self,
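For illustration, a stripped-down, self-contained sketch of the throttled batching pattern used in `_train_all_models_enhanced`: batched `asyncio.gather` followed by a short sleep so WebSocket pings, RabbitMQ heartbeats, and progress events keep flowing. `train_one` and the product names are placeholders standing in for the real `_train_single_product` coroutine.

```python
# Minimal sketch of batched, throttled async execution; train_one is a stand-in
# for the per-product training coroutine introduced in this commit.
import asyncio


async def train_one(product_id: str) -> tuple[str, dict]:
    await asyncio.sleep(0.01)  # simulate model training work
    return product_id, {"status": "success"}


async def train_all(product_ids: list[str], max_concurrent: int = 3) -> dict:
    tasks = [train_one(pid) for pid in product_ids]
    results: dict[str, dict] = {}
    for start in range(0, len(tasks), max_concurrent):
        batch = tasks[start:start + max_concurrent]
        batch_results = await asyncio.gather(*batch, return_exceptions=True)
        for item in batch_results:
            if isinstance(item, Exception):
                continue  # the real code logs and skips failed tasks
            pid, result = item
            results[pid] = result
        # Yield to the event loop so heartbeats and progress events get serviced.
        await asyncio.sleep(0.1)
    return results


if __name__ == "__main__":
    print(asyncio.run(train_all(["croissant", "baguette", "rye"], max_concurrent=2)))
```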
@@ -655,7 +730,3 @@ class EnhancedBakeryMLTrainer:
except Exception as e:
logger.error("Enhanced model evaluation failed", error=str(e))
raise


# Legacy compatibility alias
BakeryMLTrainer = EnhancedBakeryMLTrainer