1069 lines
52 KiB
Python
1069 lines
52 KiB
Python
"""
|
|
Enhanced ML Trainer with Repository Pattern
|
|
Main ML pipeline coordinator using repository pattern for data access and dependency injection
|
|
"""
|
|
|
|
from typing import Dict, List, Any, Optional
|
|
import pandas as pd
|
|
import numpy as np
|
|
from datetime import datetime
|
|
import structlog
|
|
import uuid
|
|
import time
|
|
import asyncio
|
|
|
|
from app.ml.data_processor import EnhancedBakeryDataProcessor
|
|
from app.ml.prophet_manager import BakeryProphetManager
|
|
from app.ml.product_categorizer import ProductCategorizer, ProductCategory
|
|
from app.ml.model_selector import ModelSelector
|
|
from app.ml.hybrid_trainer import HybridProphetXGBoost
|
|
from app.services.training_orchestrator import TrainingDataSet
|
|
from app.core.config import settings
|
|
|
|
from shared.database.base import create_database_manager
|
|
from shared.database.transactions import transactional
|
|
from shared.database.unit_of_work import UnitOfWork
|
|
from shared.database.exceptions import DatabaseError
|
|
|
|
from app.repositories import (
|
|
ModelRepository,
|
|
TrainingLogRepository,
|
|
PerformanceRepository,
|
|
ArtifactRepository
|
|
)
|
|
|
|
from app.services.progress_tracker import ParallelProductProgressTracker
|
|
from app.services.training_events import (
|
|
publish_training_started,
|
|
publish_data_analysis,
|
|
publish_training_completed,
|
|
publish_training_failed
|
|
)
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
class EnhancedBakeryMLTrainer:
|
|
"""
|
|
Enhanced ML trainer using repository pattern for data access and comprehensive tracking.
|
|
Orchestrates the complete ML training pipeline with proper database abstraction.
|
|
"""
|
|
|
|
def __init__(self, database_manager=None):
    """
    Wire up the trainer's collaborators.

    Args:
        database_manager: Optional pre-built database manager. When falsy, a new
            one is created for ``settings.DATABASE_URL`` under the
            "training-service" name.
    """
    db_manager = database_manager if database_manager else create_database_manager(
        settings.DATABASE_URL, "training-service"
    )
    self.database_manager = db_manager

    # Data preparation and both model trainers share the same database manager.
    self.enhanced_data_processor = EnhancedBakeryDataProcessor(db_manager)
    self.prophet_manager = BakeryProphetManager(database_manager=db_manager)
    self.hybrid_trainer = HybridProphetXGBoost(database_manager=db_manager)

    # Model-type selection and product categorization helpers.
    self.model_selector = ModelSelector()
    self.product_categorizer = ProductCategorizer()
|
|
|
|
async def _get_repositories(self, session):
    """
    Build the repository set used by the trainer, all bound to ``session``.

    Returns:
        Dict with keys 'model', 'training_log', 'performance' and 'artifact'.
    """
    return dict(
        model=ModelRepository(session),
        training_log=TrainingLogRepository(session),
        performance=PerformanceRepository(session),
        artifact=ArtifactRepository(session),
    )
|
|
|
|
async def train_tenant_models(self,
                              tenant_id: str,
                              training_dataset: TrainingDataSet,
                              job_id: Optional[str] = None,
                              session=None) -> Dict[str, Any]:
    """
    Train models for all products using repository pattern with enhanced tracking.

    Steps (percentages are those of the published progress events):
      1. Validate sales data and re-publish "training started" (0%) with the
         real product count and a recalculated time estimate.
      2. Process and categorize per-product data, publish "data analysis" (20%).
      3. Train every processed product with progress aggregation.
      4. Summarize, publish "training completed" (100%) and return the result.

    Args:
        tenant_id: Tenant identifier
        training_dataset: Prepared training dataset with aligned dates
        job_id: Training job identifier; generated when not provided
        session: Unused; a session is always opened from
            ``self.database_manager``. Kept for backward compatibility.

    Returns:
        Dictionary with per-product training results, job-level counts,
        the enhanced summary, data-range info and repository metadata.

    Raises:
        ValueError: If the sales data contains no products.
        Exception: Any pipeline error is re-raised after a "training failed"
            event has been attempted.
    """
    from collections import Counter

    if not job_id:
        job_id = f"enhanced_ml_{tenant_id}_{uuid.uuid4().hex[:8]}"

    logger.info("Starting enhanced ML training pipeline",
                job_id=job_id,
                tenant_id=tenant_id)

    try:
        async with self.database_manager.get_session() as db_session:
            repos = await self._get_repositories(db_session)

            # Convert raw dataset payloads to DataFrames
            sales_df = pd.DataFrame(training_dataset.sales_data)
            weather_df = pd.DataFrame(training_dataset.weather_data)
            traffic_df = pd.DataFrame(training_dataset.traffic_data)

            await self._validate_input_data(sales_df, tenant_id)

            products = sales_df['inventory_product_id'].unique().tolist()

            # Log per-product record counts; a single product here often points
            # at an upstream data-fetching problem (see warning below).
            total_sales_records = len(sales_df)
            sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()

            logger.info("Enhanced training pipeline - Sales data analysis",
                        total_sales_records=total_sales_records,
                        products_count=len(products),
                        products=products,
                        sales_by_product=sales_by_product)

            if not products:
                raise ValueError("No products found in sales data")
            if len(products) == 1:
                logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
                               tenant_id=tenant_id,
                               single_product_id=products[0],
                               total_sales_records=total_sales_records)
            else:
                logger.info("Multiple products detected for training",
                            products_count=len(products))

            # Event 1: Training Started (0%) - refreshed with the actual product
            # count and a time estimate based on it.
            from app.utils.time_estimation import (
                calculate_initial_estimate,
                calculate_estimated_completion_time,
                get_historical_average_estimate
            )

            # Prefer the tenant's historical per-product training time (seconds);
            # fall back to 60 seconds per product.
            default_time_per_product = 60.0
            try:
                historical_avg = await get_historical_average_estimate(
                    db_session,
                    tenant_id
                )
                avg_time_per_product = historical_avg if historical_avg else default_time_per_product
                logger.info("Using historical average for time estimation",
                            avg_time_per_product=avg_time_per_product,
                            has_historical_data=historical_avg is not None)
            except Exception as e:
                logger.warning("Could not get historical average, using default",
                               error=str(e))
                avg_time_per_product = default_time_per_product

            estimated_duration_minutes = calculate_initial_estimate(
                total_products=len(products),
                avg_training_time_per_product=avg_time_per_product
            )
            estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

            # The API endpoint already published an initial event with an
            # estimated product count; this one carries the real count.
            await publish_training_started(
                job_id=job_id,
                tenant_id=tenant_id,
                total_products=len(products),
                estimated_duration_minutes=estimated_duration_minutes,
                estimated_completion_time=estimated_completion_time.isoformat()
            )

            # Create initial training log entry
            await repos['training_log'].update_log_progress(
                job_id, 5, "data_processing", "running"
            )

            logger.info("Processing data using enhanced processor")
            processed_data = await self._process_all_products_enhanced(
                sales_df, weather_df, traffic_df, products, tenant_id, job_id
            )

            # Categorize all products for category-specific forecasting
            logger.info("Categorizing products for optimized forecasting")
            product_categories = await self._categorize_all_products(
                sales_df, processed_data
            )
            logger.info("Product categorization complete",
                        total_products=len(product_categories),
                        categories_breakdown=dict(
                            Counter(cat.value for cat in product_categories.values())
                        ))

            # Event 2: Data Analysis (20%)
            # Remaining time is estimated from the per-product average for the
            # products that survived data processing.
            products_to_train = len(processed_data)
            estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
            estimated_completion_time_data_analysis = calculate_estimated_completion_time(
                estimated_remaining_seconds / 60
            )

            await publish_data_analysis(
                job_id,
                tenant_id,
                f"Data analysis completed for {len(processed_data)} products",
                estimated_time_remaining_seconds=estimated_remaining_seconds,
                estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
            )

            logger.info("Training models with repository integration and progress aggregation")

            # Progress tracker for parallel product training (20-80%)
            progress_tracker = ParallelProductProgressTracker(
                job_id=job_id,
                tenant_id=tenant_id,
                total_products=len(processed_data)
            )

            training_results = await self._train_all_models_enhanced(
                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
            )

            summary = await self._calculate_enhanced_training_summary(
                training_results, repos, tenant_id
            )

            def _count(status: str) -> int:
                return sum(1 for r in training_results.values() if r.get('status') == status)

            successful_trainings = _count('success')
            failed_trainings = _count('error')
            skipped_trainings = _count('skipped')
            # Sum of per-product training times. Products may train
            # concurrently, so this can exceed the wall-clock duration.
            total_duration = sum(r.get('training_time_seconds', 0) for r in training_results.values())

            # Event 4: Training Completed (100%)
            await publish_training_completed(
                job_id,
                tenant_id,
                successful_trainings,
                failed_trainings,
                total_duration
            )

            date_range = training_dataset.date_range
            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "completed",
                "products_trained": successful_trainings,
                "products_failed": failed_trainings,
                "products_skipped": skipped_trainings,
                "total_products": len(products),
                "training_results": training_results,
                "enhanced_summary": summary,
                "models_trained": summary.get('models_created', {}),
                "data_info": {
                    "date_range": {
                        "start": date_range.start.isoformat(),
                        "end": date_range.end.isoformat(),
                        "duration_days": (date_range.end - date_range.start).days
                    },
                    "data_sources": [source.value for source in date_range.available_sources],
                    "constraints_applied": date_range.constraints
                },
                "repository_metadata": {
                    "total_records_created": summary.get('total_db_records', 0),
                    "performance_metrics_stored": summary.get('performance_metrics_created', 0),
                    "artifacts_created": summary.get('artifacts_created', 0)
                },
                "completed_at": datetime.now().isoformat()
            }

            logger.info("Enhanced ML training pipeline completed successfully",
                        job_id=job_id,
                        models_created=successful_trainings)

            return result

    except Exception as e:
        logger.error("Enhanced ML training pipeline failed",
                     job_id=job_id,
                     error=str(e))

        # Best-effort notification: a failure to publish must not replace the
        # original pipeline error seen by callers.
        try:
            await publish_training_failed(job_id, tenant_id, str(e))
        except Exception as publish_error:
            logger.error("Failed to publish training failed event",
                         job_id=job_id,
                         error=str(publish_error))

        raise
|
|
|
|
async def train_single_product_model(self,
                                     tenant_id: str,
                                     inventory_product_id: str,
                                     training_data: pd.DataFrame,
                                     job_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Train a model for a single product using repository pattern.

    The caller's ``training_data`` is not modified; dtype coercion and NaN
    removal happen on a copy.

    Args:
        tenant_id: Tenant identifier
        inventory_product_id: Specific inventory product to train
        training_data: Prepared training DataFrame for the product; must
            contain 'ds' and 'y' columns
        job_id: Training job identifier (optional; generated when missing)

    Returns:
        Dictionary with model training results (status, model id, numeric
        training metrics, timing and a human-readable message)

    Raises:
        ValueError: If there are too few rows or required columns are missing.
    """
    if not job_id:
        job_id = f"single_product_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"

    logger.info("Starting single product model training",
                job_id=job_id,
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id,
                data_points=len(training_data))

    try:
        async with self.database_manager.get_session() as db_session:
            repos = await self._get_repositories(db_session)

            if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
                raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")

            required_columns = ['ds', 'y']
            missing_cols = [col for col in required_columns if col not in training_data.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns in training data: {missing_cols}")

            # Single-product progress tracker
            progress_tracker = ParallelProductProgressTracker(
                job_id=job_id,
                tenant_id=tenant_id,
                total_products=1
            )

            # Work on a copy so the caller's DataFrame is not mutated by the
            # dtype coercion below. Both columns are guaranteed present.
            training_data = training_data.copy()
            training_data['ds'] = pd.to_datetime(training_data['ds'])
            training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')

            # Drop rows with NaN in any column, including 'y' values that could
            # not be coerced to numbers above.
            training_data = training_data.dropna()

            product_id, result = await self._train_single_product(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id,
                product_data=training_data,
                job_id=job_id,
                repos=repos,
                progress_tracker=progress_tracker
            )

            logger.info("Single product training completed",
                        job_id=job_id,
                        inventory_product_id=inventory_product_id,
                        result_status=result.get('status'))

            # Keep only numeric metrics (Pydantic schema requires numbers);
            # 'product_category' is a string label, not a metric.
            raw_metrics = result.get('model_info', {}).get('training_metrics', {})
            filtered_metrics = {}
            for key, value in raw_metrics.items():
                if key == 'product_category':
                    continue
                try:
                    filtered_metrics[key] = float(value) if value is not None else 0.0
                except (ValueError, TypeError):
                    continue

            status = result.get('status', 'success')
            if status == 'error':
                outcome = 'failed'
            elif status == 'skipped':
                # e.g. too few rows left after NaN removal
                outcome = 'skipped'
            else:
                outcome = 'completed'

            return {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "status": status,
                "model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
                "training_metrics": filtered_metrics,
                "training_time": result.get('training_time_seconds', 0),
                "data_points": result.get('data_points', 0),
                "message": f"Single product model training {outcome}"
            }

    except Exception as e:
        logger.error("Single product model training failed",
                     job_id=job_id,
                     inventory_product_id=inventory_product_id,
                     error=str(e))
        raise
|
|
|
|
def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
    """
    Serialize scaler values to basic Python types that can be stored in a JSON DB field.

    Conversion rules, applied recursively to nested dicts/lists/tuples:
      * numpy scalars -> the equivalent Python scalar (via ``.item()``)
      * numpy arrays  -> (nested) Python lists (via ``.tolist()``)
      * None/bool/int/float/str -> unchanged
      * dict -> dict with string keys and converted values
      * list/tuple -> list of converted values
      * other objects with ``.item()`` -> ``.item()``
      * anything else -> ``float(value)`` if possible, else ``str(value)``

    A top-level value whose conversion raises is stored as None so one bad
    entry cannot break the database write.

    Args:
        scalers: Mapping of scaler name to value; falsy input yields ``{}``.

    Returns:
        A new dict with the same top-level keys and JSON-friendly values.
    """
    if not scalers:
        return {}

    def _convert(value: Any) -> Any:
        # numpy scalars first: np.float64 subclasses float and must still be
        # turned into a plain Python float.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            # .item() would raise for arrays with more than one element
            return value.tolist()
        if value is None or isinstance(value, (bool, int, float, str)):
            return value
        if isinstance(value, dict):
            return {str(k): _convert(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [_convert(v) for v in value]
        if hasattr(value, 'item'):
            return value.item()
        try:
            return float(value)
        except (ValueError, TypeError):
            return str(value)

    serialized = {}
    for key, value in scalers.items():
        try:
            serialized[key] = _convert(value)
        except Exception:
            # If serialization fails, set to None to prevent database errors
            serialized[key] = None

    return serialized
|
|
|
|
async def _process_all_products_enhanced(self,
                                         sales_df: pd.DataFrame,
                                         weather_df: pd.DataFrame,
                                         traffic_df: pd.DataFrame,
                                         products: List[str],
                                         tenant_id: str,
                                         job_id: str) -> Dict[str, pd.DataFrame]:
    """
    Prepare per-product training data with the enhanced data processor.

    Products are handled one at a time. A product with no sales rows is
    skipped with a warning. A product whose processing raises is logged and
    left out of the result, so one bad product does not abort the rest.

    Returns:
        Mapping of inventory_product_id -> processed training DataFrame.
    """
    prepared = {}

    for product_id in products:
        try:
            logger.info("Processing data for product using enhanced processor",
                        inventory_product_id=product_id)

            is_this_product = sales_df['inventory_product_id'] == product_id
            sales_slice = sales_df[is_this_product].copy()

            if sales_slice.empty:
                logger.warning("No sales data found for product",
                               inventory_product_id=product_id)
                continue

            frame = await self.enhanced_data_processor.prepare_training_data(
                sales_data=sales_slice,
                weather_data=weather_df,
                traffic_data=traffic_df,
                inventory_product_id=product_id,
                tenant_id=tenant_id,
                job_id=job_id
            )

            prepared[product_id] = frame
            logger.info("Enhanced processing completed",
                        inventory_product_id=product_id,
                        data_points=len(frame))

        except Exception as exc:
            logger.error("Failed to process data using enhanced processor",
                         inventory_product_id=product_id,
                         error=str(exc))

    return prepared
|
|
|
|
async def _train_single_product(self,
                                tenant_id: str,
                                inventory_product_id: str,
                                product_data: pd.DataFrame,
                                job_id: str,
                                repos: Dict,
                                progress_tracker: ParallelProductProgressTracker,
                                product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
    """
    Train one product's model; the unit of work for parallel execution.

    Selects a Prophet-only or hybrid Prophet+XGBoost model via
    ``self.model_selector``, trains it, and persists the model record. A
    performance-metrics record is persisted only when both a model record and
    metrics exist. Every outcome (success, skipped, error) is reported to
    ``progress_tracker`` exactly once, so aggregated progress can reach the
    total product count.

    Exceptions raised during training are caught and returned as an
    ``'error'`` result so sibling tasks are unaffected.

    Returns:
        ``(inventory_product_id, result)`` where ``result['status']`` is
        ``'success'``, ``'skipped'`` (insufficient data) or ``'error'``.
    """
    product_start_time = time.time()

    try:
        logger.info("Training model",
                    inventory_product_id=inventory_product_id,
                    category=product_category.value)

        if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
            result = {
                'status': 'skipped',
                'reason': 'insufficient_data',
                'data_points': len(product_data),
                'min_required': settings.MIN_TRAINING_DATA_DAYS,
                'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
            }
            logger.warning("Skipping product due to insufficient data",
                           inventory_product_id=inventory_product_id,
                           data_points=len(product_data),
                           min_required=settings.MIN_TRAINING_DATA_DAYS)
        else:
            # Category-specific hyperparameters
            category_characteristics = self.product_categorizer.get_category_characteristics(product_category)

            # Decide between Prophet-only and hybrid
            model_type = self.model_selector.select_model_type(
                df=product_data,
                product_category=product_category.value
            )
            logger.info("Model type selected",
                        inventory_product_id=inventory_product_id,
                        model_type=model_type,
                        category=product_category.value)

            if model_type == "hybrid":
                model_info = await self.hybrid_trainer.train_hybrid_model(
                    tenant_id=tenant_id,
                    inventory_product_id=inventory_product_id,
                    df=product_data,
                    job_id=job_id
                )
                model_info['model_type'] = 'hybrid_prophet_xgboost'
            else:
                model_info = await self.prophet_manager.train_bakery_model(
                    tenant_id=tenant_id,
                    inventory_product_id=inventory_product_id,
                    df=product_data,
                    job_id=job_id,
                    product_category=product_category,
                    category_hyperparameters=category_characteristics.get('prophet_params', {})
                )
                model_info['model_type'] = 'prophet_optimized'

            # Keep only numeric metrics; 'product_category' is a string label.
            if model_info.get('training_metrics'):
                filtered_metrics = {}
                for key, value in model_info['training_metrics'].items():
                    if key == 'product_category':
                        continue
                    try:
                        filtered_metrics[key] = float(value) if value is not None else 0.0
                    except (ValueError, TypeError):
                        continue
                model_info['training_metrics'] = filtered_metrics

            model_record = await self._create_model_record(
                repos, tenant_id, inventory_product_id, model_info, job_id, product_data
            )

            # Metrics rows reference the model record; without a record there
            # is no valid id to link them to.
            if model_record and model_info.get('training_metrics'):
                await self._create_performance_metrics(
                    repos, model_record.id,
                    tenant_id, inventory_product_id, model_info['training_metrics']
                )

            result = {
                'status': 'success',
                'model_info': model_info,
                'model_record_id': str(model_record.id) if model_record else None,
                'data_points': len(product_data),
                'training_time_seconds': time.time() - product_start_time,
                'trained_at': datetime.now().isoformat()
            }
            logger.info("Successfully trained model",
                        inventory_product_id=inventory_product_id,
                        model_record_id=model_record.id if model_record else None)

    except Exception as e:
        logger.error("Failed to train model",
                     inventory_product_id=inventory_product_id,
                     error=str(e))
        result = {
            'status': 'error',
            'error_message': str(e),
            'data_points': len(product_data) if product_data is not None else 0,
            'training_time_seconds': time.time() - product_start_time,
            'failed_at': datetime.now().isoformat()
        }

    # Report the outcome (emits Event 3: product_completed) for every status.
    # Progress reporting is best-effort and must not change the training result.
    try:
        await progress_tracker.mark_product_completed(inventory_product_id)
    except Exception as e:
        logger.warning("Failed to report product completion to progress tracker",
                       inventory_product_id=inventory_product_id,
                       error=str(e))

    return inventory_product_id, result
|
|
|
|
async def _train_all_models_enhanced(self,
                                     tenant_id: str,
                                     processed_data: Dict[str, pd.DataFrame],
                                     job_id: str,
                                     repos: Dict,
                                     progress_tracker: ParallelProductProgressTracker,
                                     product_categories: Optional[Dict[str, ProductCategory]] = None) -> Dict[str, Any]:
    """
    Train models with throttled parallel execution and progress tracking.

    Products are trained in batches of ``settings.MAX_CONCURRENT_TRAININGS``
    (default 3, minimum 1) via ``asyncio.gather``. The loop sleeps briefly
    between batches so other event-loop work (heartbeats, progress events) can
    run. Every product in ``processed_data`` gets an entry in the result; a
    task that raised instead of returning is recorded with status ``'error'``.

    Args:
        product_categories: Optional mapping of product id to category; missing
            products use ``ProductCategory.UNKNOWN``.

    Returns:
        Mapping of inventory_product_id -> per-product result dict.
    """
    total_products = len(processed_data)
    logger.info(f"Starting throttled parallel training for {total_products} products")

    categories = product_categories or {}
    product_ids = list(processed_data.keys())

    # Guard against 0/negative/None config values; range() rejects a step of 0.
    configured_limit = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)
    max_concurrent = max(1, int(configured_limit or 1))

    logger.info(f"Executing training with max {max_concurrent} concurrent operations",
                total_products=total_products)

    # NOTE(review): every concurrent task receives the same `repos`, i.e.
    # repositories bound to one DB session. If that session is not safe for
    # concurrent use (e.g. SQLAlchemy AsyncSession), batches larger than 1 may
    # interleave operations on it - confirm.
    total_batches = (total_products + max_concurrent - 1) // max_concurrent
    results_list = []
    for batch_index, start in enumerate(range(0, total_products, max_concurrent)):
        batch_ids = product_ids[start:start + max_concurrent]
        # Coroutines are created per batch so none are left un-awaited if an
        # earlier batch is interrupted.
        batch_results = await asyncio.gather(
            *(
                self._train_single_product(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    product_data=processed_data[product_id],
                    job_id=job_id,
                    repos=repos,
                    progress_tracker=progress_tracker,
                    product_category=categories.get(product_id, ProductCategory.UNKNOWN)
                )
                for product_id in batch_ids
            ),
            return_exceptions=True
        )
        results_list.extend(batch_results)

        # Yield control for 100ms so WebSocket pings, RabbitMQ heartbeats and
        # progress events can be processed during long training operations.
        await asyncio.sleep(0.1)

        logger.debug(
            "Training batch completed, yielding to event loop",
            batch_num=batch_index + 1,
            total_batches=total_batches,
            products_completed=len(results_list),
            total_products=total_products
        )

    summary = progress_tracker.get_progress()
    logger.info("Throttled parallel training completed",
                total=summary['total_products'],
                completed=summary['products_completed'])

    # gather() preserves order, so results line up with product_ids.
    training_results = {}
    for product_id, outcome in zip(product_ids, results_list):
        if isinstance(outcome, BaseException):
            logger.error(f"Training task failed with exception: {outcome}",
                         inventory_product_id=product_id)
            product_df = processed_data[product_id]
            training_results[product_id] = {
                'status': 'error',
                'error_message': str(outcome),
                'data_points': len(product_df) if product_df is not None else 0,
                'training_time_seconds': 0,
                'failed_at': datetime.now().isoformat()
            }
            continue

        returned_id, product_result = outcome
        training_results[returned_id] = product_result

    logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
    return training_results
|
|
|
|
async def _create_model_record(self,
                               repos: Dict,
                               tenant_id: str,
                               inventory_product_id: str,
                               model_info: Dict,
                               job_id: str,
                               processed_data: pd.DataFrame):
    """
    Create the model record, and a model-file artifact when a path is known.

    The training period is taken from the min/max of ``processed_data['ds']``
    and stored as naive datetimes. Metrics come from
    ``model_info['training_metrics']``; a missing or None metrics dict yields
    zeros instead of failing the whole record.

    Returns:
        The created model record, or None if anything fails (error is logged).
    """
    try:
        training_start_date = None
        training_end_date = None
        if 'ds' in processed_data.columns and not processed_data.empty:
            # Coerce to datetime64 so min/max are Timestamps, not object values
            ds_datetime = pd.to_datetime(processed_data['ds'])
            min_ts = ds_datetime.min()
            max_ts = ds_datetime.max()

            # Convert to python datetime with timezone removal
            if pd.notna(min_ts):
                training_start_date = pd.Timestamp(min_ts).to_pydatetime().replace(tzinfo=None)
            if pd.notna(max_ts):
                training_end_date = pd.Timestamp(max_ts).to_pydatetime().replace(tzinfo=None)

        try:
            features_used = [str(col) for col in processed_data.columns]
        except Exception:
            features_used = []

        # `training_metrics` may be absent or explicitly None.
        metrics = model_info.get("training_metrics") or {}

        def _metric_or_zero(name: str):
            value = metrics.get(name)
            return float(value) if value is not None else 0

        data_quality = model_info.get("data_quality_score")

        model_data = {
            "tenant_id": tenant_id,
            "inventory_product_id": inventory_product_id,
            "job_id": job_id,
            # NOTE(review): stored as "enhanced_prophet" even when
            # model_info['model_type'] says hybrid; confirm consumers before
            # switching to model_info's value.
            "model_type": "enhanced_prophet",
            "model_path": model_info.get("model_path"),
            "metadata_path": model_info.get("metadata_path"),
            "mape": _metric_or_zero("mape"),
            "mae": _metric_or_zero("mae"),
            "rmse": _metric_or_zero("rmse"),
            "r2_score": _metric_or_zero("r2"),
            "training_samples": int(len(processed_data)),
            "hyperparameters": self._serialize_scalers(model_info.get("hyperparameters", {})),
            "features_used": features_used,
            # Include scalers for prediction consistency
            "normalization_params": self._serialize_scalers(self.enhanced_data_processor.get_scalers()) or {},
            "product_category": model_info.get("product_category", "unknown"),
            "is_active": True,
            "is_production": True,
            "data_quality_score": float(data_quality) if data_quality is not None else 100.0,
            "training_start_date": training_start_date,
            "training_end_date": training_end_date
        }

        model_record = await repos['model'].create_model(model_data)
        logger.info("Created enhanced model record",
                    inventory_product_id=inventory_product_id,
                    model_id=model_record.id)

        if model_info.get("model_path"):
            await repos['artifact'].create_artifact({
                "model_id": str(model_record.id),
                "tenant_id": tenant_id,
                "artifact_type": "enhanced_model_file",
                "file_path": model_info["model_path"],
                "storage_location": "local"
            })

        return model_record

    except Exception as e:
        logger.error("Failed to create enhanced model record",
                     inventory_product_id=inventory_product_id,
                     error=str(e))
        return None
|
|
|
|
async def _create_performance_metrics(self,
                                      repos: Dict,
                                      model_id: str,
                                      tenant_id: str,
                                      inventory_product_id: str,
                                      metrics: Dict):
    """
    Create a performance-metrics record for a model via the repository.

    ``accuracy_percentage`` is derived as ``100 - mape``. Missing individual
    metrics are stored as None (``evaluation_samples`` as 0). Does nothing when
    ``model_id`` is empty, because a stringified None would not reference a
    real model. Repository errors are logged, not raised.
    """
    if not model_id:
        logger.warning("Skipping performance metrics: no model record id",
                       inventory_product_id=inventory_product_id)
        return

    try:
        def _as_float(name: str):
            value = metrics.get(name)
            return float(value) if value is not None else None

        mape = metrics.get("mape")
        data_points = metrics.get("data_points")

        metric_data = {
            "model_id": str(model_id),
            "tenant_id": tenant_id,
            "inventory_product_id": inventory_product_id,
            "mae": _as_float("mae"),
            "mse": _as_float("mse"),
            "rmse": _as_float("rmse"),
            "mape": _as_float("mape"),
            "r2_score": _as_float("r2"),
            "accuracy_percentage": float(100 - mape) if mape is not None else None,
            "evaluation_samples": int(data_points) if data_points is not None else 0
        }

        await repos['performance'].create_performance_metric(metric_data)
        logger.info("Created enhanced performance metrics",
                    inventory_product_id=inventory_product_id,
                    model_id=model_id)

    except Exception as e:
        logger.error("Failed to create enhanced performance metrics",
                     inventory_product_id=inventory_product_id,
                     error=str(e))
|
|
|
|
async def _calculate_enhanced_training_summary(self,
|
|
training_results: Dict[str, Any],
|
|
repos: Dict,
|
|
tenant_id: str) -> Dict[str, Any]:
|
|
"""Calculate enhanced summary statistics with repository data"""
|
|
total_products = len(training_results)
|
|
successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
|
|
failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
|
|
skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])
|
|
|
|
# Calculate average training metrics for successful models
|
|
successful_results = [r for r in training_results.values() if r.get('status') == 'success']
|
|
|
|
avg_metrics = {}
|
|
if successful_results:
|
|
metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
|
|
|
|
if metrics_list and all(metrics_list):
|
|
avg_metrics = {
|
|
'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
|
|
'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
|
|
'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
|
|
'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
|
|
'avg_training_time': round(np.mean([r.get('training_time_seconds', 0) for r in successful_results]), 2)
|
|
}
|
|
|
|
# Calculate data quality insights
|
|
data_points_list = [r.get('data_points', 0) for r in training_results.values()]
|
|
|
|
# Get database statistics
|
|
try:
|
|
# Get tenant model count from repository
|
|
tenant_models = await repos['model'].get_models_by_tenant(tenant_id)
|
|
models_created = [r.get('model_record_id') for r in successful_results if r.get('model_record_id')]
|
|
|
|
db_stats = {
|
|
'total_tenant_models': len(tenant_models),
|
|
'models_created_this_job': len(models_created),
|
|
'total_db_records': len(models_created),
|
|
'performance_metrics_created': len(models_created), # One per model
|
|
'artifacts_created': len([r for r in successful_results if r.get('model_info', {}).get('model_path')])
|
|
}
|
|
except Exception as e:
|
|
logger.warning("Failed to get database statistics", error=str(e))
|
|
db_stats = {
|
|
'total_tenant_models': 0,
|
|
'models_created_this_job': 0,
|
|
'total_db_records': 0,
|
|
'performance_metrics_created': 0,
|
|
'artifacts_created': 0
|
|
}
|
|
|
|
# Build models_created with proper model result structure
|
|
models_created = {}
|
|
for product, result in training_results.items():
|
|
if result.get('status') == 'success' and result.get('model_info'):
|
|
model_info = result['model_info']
|
|
models_created[product] = {
|
|
'status': 'completed',
|
|
'model_path': model_info.get('model_path'),
|
|
'metadata_path': model_info.get('metadata_path'),
|
|
'metrics': model_info.get('training_metrics', {}),
|
|
'hyperparameters': model_info.get('hyperparameters', {}),
|
|
'features_used': model_info.get('features_used', []),
|
|
'data_points': result.get('data_points', 0),
|
|
'data_quality_score': model_info.get('data_quality_score', 100.0),
|
|
'model_record_id': result.get('model_record_id')
|
|
}
|
|
|
|
enhanced_summary = {
|
|
'total_products': total_products,
|
|
'successful_products': successful_products,
|
|
'failed_products': failed_products,
|
|
'skipped_products': skipped_products,
|
|
'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
|
|
'enhanced_average_metrics': avg_metrics,
|
|
'enhanced_data_summary': {
|
|
'total_data_points': sum(data_points_list),
|
|
'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
|
|
'min_data_points': min(data_points_list) if data_points_list else 0,
|
|
'max_data_points': max(data_points_list) if data_points_list else 0
|
|
},
|
|
'database_statistics': db_stats,
|
|
'models_created': models_created
|
|
}
|
|
|
|
# Add database statistics to the summary
|
|
enhanced_summary.update(db_stats)
|
|
|
|
return enhanced_summary
|
|
|
|
async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
|
|
"""Validate input sales data with enhanced error reporting"""
|
|
if sales_df.empty:
|
|
raise ValueError(f"No sales data provided for tenant {tenant_id}")
|
|
|
|
# Handle quantity column mapping
|
|
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
|
|
sales_df['quantity'] = sales_df['quantity_sold']
|
|
logger.info("Mapped quantity column",
|
|
from_column='quantity_sold',
|
|
to_column='quantity')
|
|
|
|
required_columns = ['date', 'inventory_product_id', 'quantity']
|
|
missing_columns = [col for col in required_columns if col not in sales_df.columns]
|
|
if missing_columns:
|
|
raise ValueError(f"Missing required columns: {missing_columns}")
|
|
|
|
# Check for valid dates
|
|
try:
|
|
sales_df['date'] = pd.to_datetime(sales_df['date'])
|
|
except Exception:
|
|
raise ValueError("Invalid date format in sales data")
|
|
|
|
# Check for valid quantities
|
|
if not sales_df['quantity'].dtype in ['int64', 'float64']:
|
|
try:
|
|
sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
|
|
except Exception:
|
|
raise ValueError("Quantity column must be numeric")
|
|
|
|
async def _categorize_all_products(
    self,
    sales_df: pd.DataFrame,
    processed_data: Dict[str, pd.DataFrame]
) -> Dict[str, ProductCategory]:
    """
    Categorize all products for category-specific forecasting.

    For every product id in ``processed_data``, the matching rows of
    ``sales_df`` supply:

    * a display name: the first non-null value from ``product_name``,
      ``name`` or ``item_name``, checked in that order, else ``"unknown"``;
    * the ``date``/``quantity`` history.

    Both are passed to ``self.product_categorizer.categorize_product``. Any
    failure for a single product is logged and that product is mapped to
    ``ProductCategory.UNKNOWN``, so one bad product never aborts the rest.

    Args:
        sales_df: Raw sales data with product names
        processed_data: Processed data by product ID (only its keys are used)

    Returns:
        Dict mapping inventory_product_id to ProductCategory
    """
    product_categories = {}

    for inventory_product_id in processed_data:
        try:
            # Get product rows from sales data (may be empty)
            product_sales = sales_df[sales_df['inventory_product_id'] == inventory_product_id]

            # Extract product name: first non-null value across candidate columns,
            # so a missing/NaN first entry doesn't become the literal name "nan".
            product_name = "unknown"
            for name_col in ('product_name', 'name', 'item_name'):
                if name_col not in product_sales.columns:
                    continue
                names = product_sales[name_col].dropna()
                if not names.empty:
                    product_name = names.iloc[0]
                    break

            # Prepare sales data for pattern analysis
            sales_for_analysis = product_sales[['date', 'quantity']].copy() if 'date' in product_sales.columns else None

            category = self.product_categorizer.categorize_product(
                product_name=str(product_name),
                product_id=inventory_product_id,
                sales_data=sales_for_analysis
            )
            product_categories[inventory_product_id] = category

            logger.debug("Product categorized",
                         inventory_product_id=inventory_product_id,
                         product_name=product_name,
                         category=category.value)

        except Exception as e:
            logger.warning("Failed to categorize product",
                           inventory_product_id=inventory_product_id,
                           error=str(e))
            product_categories[inventory_product_id] = ProductCategory.UNKNOWN

    return product_categories
|
|
|
|
async def evaluate_model_performance_enhanced(self,
                                              tenant_id: str,
                                              inventory_product_id: str,
                                              model_path: str,
                                              test_dataset: TrainingDataSet) -> Dict[str, Any]:
    """
    Enhanced model evaluation with repository integration.

    Scores the model at ``model_path`` against the test rows for one product.
    When a model record exists for the product, the resulting metrics are
    persisted against the most recently created record.

    Args:
        tenant_id: Tenant that owns the model and data.
        inventory_product_id: Product to evaluate; test sales are filtered to it.
        model_path: Path handed to ``prophet_manager.generate_forecast``.
        test_dataset: Held-out data. Its ``sales_data``, ``weather_data`` and
            ``traffic_data`` are loaded into DataFrames, and its
            ``date_range`` is echoed in the result.

    Returns:
        Dict with ``enhanced_evaluation_metrics`` (``mae``, ``rmse``,
        ``r2_score``, ``mape``), sample counts, the test period, an
        evaluation timestamp and ``repository_integration`` flags.

    Raises:
        ValueError: If the test set has no rows for the product, or no
            actual/predicted pairs remain after aligning their lengths.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        logger.info("Enhanced model evaluation starting",
                    tenant_id=tenant_id,
                    inventory_product_id=inventory_product_id)

        # Get database session and repositories
        async with self.database_manager.get_session() as db_session:
            repos = await self._get_repositories(db_session)

            # Convert test data to DataFrames
            test_sales_df = pd.DataFrame(test_dataset.sales_data)
            test_weather_df = pd.DataFrame(test_dataset.weather_data)
            test_traffic_df = pd.DataFrame(test_dataset.traffic_data)

            # Filter for specific product
            product_test_sales = test_sales_df[
                test_sales_df['inventory_product_id'] == inventory_product_id
            ].copy()
            if product_test_sales.empty:
                raise ValueError(f"No test data found for product: {inventory_product_id}")

            # Process test data using enhanced processor
            processed_test_data = await self.enhanced_data_processor.prepare_training_data(
                sales_data=product_test_sales,
                weather_data=test_weather_df,
                traffic_data=test_traffic_df,
                inventory_product_id=inventory_product_id,
                tenant_id=tenant_id
            )

            # Prediction frame: 'ds' plus every non-target column as a regressor
            regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
            future_dates = processed_test_data[['ds']].copy()
            for col in regressor_columns:
                future_dates[col] = processed_test_data[col]

            forecast = await self.prophet_manager.generate_forecast(
                model_path=model_path,
                future_dates=future_dates,
                regressor_columns=regressor_columns
            )

            from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

            y_true = processed_test_data['y'].values
            y_pred = forecast['yhat'].values

            # Compare position-by-position over the overlapping prefix only
            min_len = min(len(y_true), len(y_pred))
            if min_len == 0:
                raise ValueError(
                    f"No overlapping test/prediction points for product: {inventory_product_id}"
                )
            y_true = y_true[:min_len]
            y_pred = y_pred[:min_len]

            mse = float(mean_squared_error(y_true, y_pred))
            metrics = {
                "mae": float(mean_absolute_error(y_true, y_pred)),
                "rmse": float(np.sqrt(mse)),
                "r2_score": float(r2_score(y_true, y_pred))
            }

            # MAPE only over actuals > 0.1 (avoids dividing by ~0), capped at 200%
            non_zero_mask = y_true > 0.1
            if np.any(non_zero_mask):
                mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
                metrics["mape"] = float(min(mape, 200))
            else:
                metrics["mape"] = 100.0

            # Store evaluation metrics against the newest model record, if any
            model_records = await repos['model'].get_models_by_product(tenant_id, inventory_product_id)
            if model_records:
                latest_model = max(model_records, key=lambda record: record.created_at)
                # _create_performance_metrics reads "r2", "mse" and "data_points";
                # supply them on a copy so the returned metrics keep their keys.
                stored_metrics = {
                    **metrics,
                    "mse": mse,
                    "r2": metrics["r2_score"],
                    "data_points": min_len,
                }
                await self._create_performance_metrics(
                    repos, latest_model.id, tenant_id, inventory_product_id, stored_metrics
                )

            return {
                "tenant_id": tenant_id,
                "inventory_product_id": inventory_product_id,
                "enhanced_evaluation_metrics": metrics,
                "test_samples": len(processed_test_data),
                "prediction_samples": len(forecast),
                "test_period": {
                    "start": test_dataset.date_range.start.isoformat(),
                    "end": test_dataset.date_range.end.isoformat()
                },
                "evaluated_at": datetime.now().isoformat(),
                "repository_integration": {
                    # A metrics write is only attempted when a model record exists;
                    # _create_performance_metrics logs rather than raises on failure.
                    "metrics_stored": bool(model_records),
                    "model_record_found": bool(model_records)
                }
            }

    except Exception as e:
        logger.error("Enhanced model evaluation failed",
                     tenant_id=tenant_id,
                     inventory_product_id=inventory_product_id,
                     error=str(e))
        raise
|