# bakery-ia/services/training/app/ml/trainer.py
"""
Enhanced ML Trainer with Repository Pattern
Main ML pipeline coordinator using repository pattern for data access and dependency injection
"""
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
from datetime import datetime, timezone
import structlog
import uuid
import time
import asyncio
from app.ml.data_processor import EnhancedBakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.product_categorizer import ProductCategorizer, ProductCategory
from app.ml.model_selector import ModelSelector
from app.ml.hybrid_trainer import HybridProphetXGBoost
from app.services.training_orchestrator import TrainingDataSet
from app.core.config import settings
from shared.database.base import create_database_manager
from shared.database.transactions import transactional
from shared.database.unit_of_work import UnitOfWork
from shared.database.exceptions import DatabaseError
from app.repositories import (
ModelRepository,
TrainingLogRepository,
PerformanceRepository,
ArtifactRepository
)
from app.services.progress_tracker import ParallelProductProgressTracker
from app.services.training_events import (
publish_training_started,
publish_data_analysis,
publish_training_completed,
publish_training_failed
)
logger = structlog.get_logger()
class EnhancedBakeryMLTrainer:
"""
Enhanced ML trainer using repository pattern for data access and comprehensive tracking.
Orchestrates the complete ML training pipeline with proper database abstraction.
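
    Example (illustrative sketch; assumes a prepared ``TrainingDataSet`` from the
    training orchestrator and an async calling context):

        trainer = EnhancedBakeryMLTrainer()
        results = await trainer.train_tenant_models(
            tenant_id="tenant-123",
            training_dataset=training_dataset,
        )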
"""
def __init__(self, database_manager=None):
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
self.enhanced_data_processor = EnhancedBakeryDataProcessor(self.database_manager)
self.prophet_manager = BakeryProphetManager(database_manager=self.database_manager)
self.hybrid_trainer = HybridProphetXGBoost(database_manager=self.database_manager)
self.model_selector = ModelSelector()
self.product_categorizer = ProductCategorizer()
async def _get_repositories(self, session):
"""Initialize repositories with session"""
return {
'model': ModelRepository(session),
'training_log': TrainingLogRepository(session),
'performance': PerformanceRepository(session),
'artifact': ArtifactRepository(session)
}
async def train_tenant_models(self,
tenant_id: str,
training_dataset: TrainingDataSet,
job_id: Optional[str] = None,
session=None) -> Dict[str, Any]:
"""
Train models for all products using repository pattern with enhanced tracking.
Args:
tenant_id: Tenant identifier
training_dataset: Prepared training dataset with aligned dates
job_id: Training job identifier
session: Database session to use (if None, creates one)
Returns:
Dictionary with training results for each product
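
        Example return value (abridged; values are illustrative):
            {
                "job_id": "enhanced_ml_<tenant>_<hex>",
                "status": "completed",
                "products_trained": 12,
                "products_failed": 0,
                "products_skipped": 1,
                "training_results": {...},
                "completed_at": "2025-01-01T00:00:00",
            }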
"""
if not job_id:
job_id = f"enhanced_ml_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info("Starting enhanced ML training pipeline",
job_id=job_id,
tenant_id=tenant_id)
try:
            # Use the provided session or create a new one; reusing the caller's
            # session avoids nested sessions and the deadlocks they can cause
            should_create_session = session is None
if should_create_session:
async with self.database_manager.get_session() as db_session:
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, db_session
)
else:
# Use the provided session (no context manager needed since caller manages it)
return await self._execute_training_pipeline(
tenant_id, training_dataset, job_id, session
)
except Exception as e:
logger.error("Enhanced ML training pipeline failed",
job_id=job_id,
error=str(e),
exc_info=True)
# Publish training failed event
await publish_training_failed(job_id, tenant_id, str(e))
raise
async def _execute_training_pipeline(self, tenant_id: str, training_dataset: TrainingDataSet,
job_id: str, session) -> Dict[str, Any]:
"""
Execute the training pipeline with the given session.
This is extracted to avoid code duplication when handling provided vs. created sessions.
"""
# Get repositories with the session
repos = await self._get_repositories(session)
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products from the sales data
products = sales_df['inventory_product_id'].unique().tolist()
# Debug: Log sales data details to understand why only one product is found
total_sales_records = len(sales_df)
sales_by_product = sales_df.groupby('inventory_product_id').size().to_dict()
logger.info("Enhanced training pipeline - Sales data analysis",
total_sales_records=total_sales_records,
products_count=len(products),
products=products,
sales_by_product=sales_by_product)
if len(products) == 1:
logger.warning("Only ONE product found in sales data - this may indicate a data fetching issue",
tenant_id=tenant_id,
single_product_id=products[0],
total_sales_records=total_sales_records)
elif len(products) == 0:
raise ValueError("No products found in sales data")
else:
logger.info("Multiple products detected for training",
products_count=len(products))
# Event 1: Training Started (0%) - update with actual product count AND time estimates
# Calculate accurate time estimates now that we know the actual product count
from app.utils.time_estimation import (
calculate_initial_estimate,
calculate_estimated_completion_time,
get_historical_average_estimate
)
# Try to get historical average for more accurate estimates
try:
historical_avg = await get_historical_average_estimate(
session,
tenant_id
)
avg_time_per_product = historical_avg if historical_avg else 60.0
logger.info("Using historical average for time estimation",
avg_time_per_product=avg_time_per_product,
has_historical_data=historical_avg is not None)
except Exception as e:
logger.warning("Could not get historical average, using default",
error=str(e))
avg_time_per_product = 60.0
estimated_duration_minutes = calculate_initial_estimate(
total_products=len(products),
avg_training_time_per_product=avg_time_per_product
)
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Note: Initial event was already published by API endpoint with estimated product count,
# this updates with real count and recalculated time estimates based on actual data
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(products),
estimated_duration_minutes=estimated_duration_minutes,
estimated_completion_time=estimated_completion_time.isoformat()
)
# Create initial training log entry
await repos['training_log'].update_log_progress(
job_id, 5, "data_processing", "running"
)
# ✅ FIX: Flush the session to ensure the update is committed before proceeding
# This prevents deadlocks when training methods need to acquire locks
await session.flush()
logger.debug("Flushed session after initial progress update")
# Process data for each product using enhanced processor
logger.info("Processing data using enhanced processor")
processed_data = await self._process_all_products_enhanced(
sales_df, weather_df, traffic_df, products, tenant_id, job_id, training_dataset.poi_features, session
)
# Validate that we have processed data
if not processed_data or len(processed_data) == 0:
error_msg = f"No products could be processed successfully. Found {len(products)} products in sales data but all failed during processing."
logger.error("Training aborted - no processed data",
tenant_id=tenant_id,
job_id=job_id,
products_found=len(products),
products_processed=0)
raise ValueError(error_msg)
logger.info(f"Successfully processed {len(processed_data)} out of {len(products)} products",
products_processed=len(processed_data),
products_found=len(products))
# Categorize all products for category-specific forecasting
logger.info("Categorizing products for optimized forecasting")
product_categories = await self._categorize_all_products(
sales_df, processed_data
)
logger.info("Product categorization complete",
total_products=len(product_categories),
categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
for cat in set(product_categories.values())})
# Event 2: Data Analysis (20%)
# Recalculate time remaining based on elapsed time
start_time = await repos['training_log'].get_start_time(job_id)
elapsed_seconds = 0
if start_time:
elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
# Estimate remaining time: we've done ~20% of work (data analysis)
# Remaining 80% includes training all products
products_to_train = len(processed_data)
estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
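        # avg_time_per_product is treated as seconds here; the /60 below converts the
        # remaining estimate to the minutes expected by calculate_estimated_completion_time()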
# Recalculate estimated completion time
estimated_completion_time_data_analysis = calculate_estimated_completion_time(
estimated_remaining_seconds / 60
)
await publish_data_analysis(
job_id,
tenant_id,
f"Data analysis completed for {len(processed_data)} products",
estimated_time_remaining_seconds=estimated_remaining_seconds,
estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
)
# Train models for each processed product with progress aggregation
logger.info("Training models with repository integration and progress aggregation")
# Create progress tracker for parallel product training (20-80%)
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=len(processed_data)
)
# Train all models in parallel (without DB writes to avoid session conflicts)
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_results = await self._train_all_models_enhanced(
tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, session
)
# Write all training results to database sequentially (after parallel training completes)
logger.info("Writing training results to database sequentially")
training_results = await self._write_training_results_to_database(
tenant_id, job_id, training_results, repos
)
# ✅ CRITICAL FIX: Commit the session to persist model records to database
# Without this commit, all model records created above are lost when session closes
await session.commit()
logger.info("Committed model records to database",
models_created=len([r for r in training_results.values() if 'model_record_id' in r]))
# Calculate overall training summary with enhanced metrics
summary = await self._calculate_enhanced_training_summary(
training_results, repos, tenant_id
)
# Calculate successful and failed trainings
successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])
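        # Note: this sums per-product training durations, so with parallel training it can
        # exceed the actual wall-clock time of the job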
# Event 4: Training Completed (100%)
await publish_training_completed(
job_id,
tenant_id,
successful_trainings,
failed_trainings,
total_duration
)
# Create comprehensive result with repository data
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
"total_products": len(products),
"training_results": training_results,
"enhanced_summary": summary,
"models_trained": summary.get('models_created', {}),
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"repository_metadata": {
"total_records_created": summary.get('total_db_records', 0),
"performance_metrics_stored": summary.get('performance_metrics_created', 0),
"artifacts_created": summary.get('artifacts_created', 0)
},
"completed_at": datetime.now().isoformat()
}
logger.info("Enhanced ML training pipeline completed successfully",
job_id=job_id,
models_created=len([r for r in training_results.values() if r.get('status') == 'success']))
return result
async def train_single_product_model(self,
tenant_id: str,
inventory_product_id: str,
training_data: pd.DataFrame,
job_id: str = None,
session=None) -> Dict[str, Any]:
"""
Train a model for a single product using repository pattern.
Args:
tenant_id: Tenant identifier
inventory_product_id: Specific inventory product to train
training_data: Prepared training DataFrame for the product
job_id: Training job identifier (optional)
session: Database session to use (if None, creates one)
Returns:
Dictionary with model training results
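
        Example (sketch; ``training_data`` must already contain Prophet-style ``ds``/``y`` columns):
            result = await trainer.train_single_product_model(
                tenant_id="tenant-123",
                inventory_product_id="prod-42",
                training_data=training_data,
            )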
"""
if not job_id:
job_id = f"single_product_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
logger.info("Starting single product model training",
job_id=job_id,
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
data_points=len(training_data))
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
if should_create_session:
# Only create a session if one wasn't provided
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=db_session # Pass the session to prevent nested sessions
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# Only commit if this is our own session (not a parent session)
# Commit after we're done with all database operations
await db_session.commit()
logger.info("Committed single product model record to database",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
else:
# Use the provided session
repos = await self._get_repositories(session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=session # Pass the provided session
)
2025-11-05 13:34:56 +01:00
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# For provided sessions, do NOT commit here - let the calling method handle commits
# This prevents committing a parent transaction prematurely
logger.info("Single product model processed (commit handled by caller)",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
except Exception as e:
logger.error("Single product model training failed",
job_id=job_id,
inventory_product_id=inventory_product_id,
error=str(e),
exc_info=True)
raise
def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
"""
Serialize scaler objects to basic Python types that can be stored in database.
This prevents issues with storing complex sklearn objects in JSON fields.
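
        Example (illustrative):
            {"scale": np.float64(1.5), "cols": ("a", "b")}  ->  {"scale": 1.5, "cols": ["a", "b"]}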
"""
if not scalers:
return {}
serialized = {}
for key, value in scalers.items():
try:
# Convert numpy scalars to Python native types
if hasattr(value, 'item'): # numpy scalars
serialized[key] = value.item()
elif isinstance(value, (np.integer, np.floating)):
serialized[key] = value.item() # Convert numpy types to Python types
elif isinstance(value, (int, float, str, bool, type(None))):
serialized[key] = value # Already basic type
elif isinstance(value, (list, tuple)):
# Convert list/tuple elements to basic types
serialized[key] = [v.item() if hasattr(v, 'item') else v for v in value]
else:
# For complex objects, try to convert to string representation
# or store as float if it's numeric
try:
serialized[key] = float(value)
except (ValueError, TypeError):
# If all else fails, convert to string
serialized[key] = str(value)
except Exception:
# If serialization fails, set to None to prevent database errors
serialized[key] = None
return serialized
async def _process_all_products_enhanced(self,
sales_df: pd.DataFrame,
weather_df: pd.DataFrame,
traffic_df: pd.DataFrame,
products: List[str],
tenant_id: str,
job_id: str,
poi_features: Dict[str, Any],
session=None) -> Dict[str, pd.DataFrame]:
"""Process data for all products using enhanced processor with repository tracking"""
processed_data = {}
for inventory_product_id in products:
try:
logger.info("Processing data for product using enhanced processor",
inventory_product_id=inventory_product_id)
# Filter sales data for this product
product_sales = sales_df[sales_df['inventory_product_id'] == inventory_product_id].copy()
2025-07-28 19:28:39 +02:00
if product_sales.empty:
2025-08-08 09:08:41 +02:00
logger.warning("No sales data found for product",
2025-08-14 16:47:34 +02:00
inventory_product_id=inventory_product_id)
2025-07-28 19:28:39 +02:00
continue
2025-08-08 09:08:41 +02:00
# Use enhanced data processor with repository tracking
2025-11-05 18:47:20 +01:00
# Pass the session to prevent nested session issues
2025-08-08 09:08:41 +02:00
processed_product_data = await self.enhanced_data_processor.prepare_training_data(
2025-07-19 16:59:37 +02:00
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
2025-08-14 16:47:34 +02:00
inventory_product_id=inventory_product_id,
poi_features=poi_features, # POI features for location-based forecasting
2025-08-08 09:08:41 +02:00
tenant_id=tenant_id,
2025-11-05 18:47:20 +01:00
job_id=job_id,
session=session # Pass the session to avoid creating new ones
2025-07-19 16:59:37 +02:00
)
2025-08-14 16:47:34 +02:00
processed_data[inventory_product_id] = processed_product_data
2025-08-08 09:08:41 +02:00
logger.info("Enhanced processing completed",
2025-08-14 16:47:34 +02:00
inventory_product_id=inventory_product_id,
2025-08-08 09:08:41 +02:00
data_points=len(processed_product_data))
2025-07-19 16:59:37 +02:00
except Exception as e:
2025-08-08 09:08:41 +02:00
logger.error("Failed to process data using enhanced processor",
2025-08-14 16:47:34 +02:00
inventory_product_id=inventory_product_id,
2025-08-08 09:08:41 +02:00
error=str(e))
continue
return processed_data
async def _train_single_product(self,
tenant_id: str,
inventory_product_id: str,
product_data: pd.DataFrame,
job_id: str,
repos: Dict,
progress_tracker: ParallelProductProgressTracker,
product_category: ProductCategory = ProductCategory.UNKNOWN,
session = None) -> tuple[str, Dict[str, Any]]:
"""
Train a single product model - used for parallel execution with progress aggregation.
Note: This method ONLY trains the model and collects results. Database writes happen
separately to avoid concurrent session conflicts.
Args:
session: Database session to use for training (prevents nested session issues)
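
        Returns:
            Tuple of (inventory_product_id, result_dict) so results gathered from
            parallel asyncio tasks can be mapped back to their product.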
"""
product_start_time = time.time()
try:
logger.info("Training model",
inventory_product_id=inventory_product_id,
category=product_category.value)
# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
result = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}',
'product_data': product_data, # Store for later DB writes
'product_category': product_category
}
logger.warning("Skipping product due to insufficient data",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
min_required=settings.MIN_TRAINING_DATA_DAYS)
return inventory_product_id, result
# Get category-specific hyperparameters
category_characteristics = self.product_categorizer.get_category_characteristics(product_category)
# Determine which model type to use (Prophet vs Hybrid)
model_type = self.model_selector.select_model_type(
df=product_data,
product_category=product_category.value
)
logger.info("Model type selected",
inventory_product_id=inventory_product_id,
model_type=model_type,
category=product_category.value)
# Train the selected model
# ✅ FIX: Pass session to training methods to avoid nested session issues
if model_type == "hybrid":
# Train hybrid Prophet + XGBoost model
model_info = await self.hybrid_trainer.train_hybrid_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=product_data,
job_id=job_id,
session=session
)
model_info['model_type'] = 'hybrid_prophet_xgboost'
else:
# Train Prophet-only model with category-specific settings
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
df=product_data,
job_id=job_id,
product_category=product_category,
category_hyperparameters=category_characteristics.get('prophet_params', {}),
session=session
)
model_info['model_type'] = 'prophet_optimized'
# Filter training metrics to exclude non-numeric values (e.g., product_category)
if 'training_metrics' in model_info and model_info['training_metrics']:
raw_metrics = model_info['training_metrics']
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
model_info['training_metrics'] = filtered_metrics
# IMPORTANT: Do NOT write to database here - causes concurrent session conflicts
# Store all info needed for later DB writes (done sequentially after all training completes)
result = {
'status': 'success',
'model_info': model_info,
'data_points': len(product_data),
'training_time_seconds': time.time() - product_start_time,
'trained_at': datetime.now().isoformat(),
# Store data needed for DB writes later
'product_data': product_data,
'product_category': product_category
}
logger.info("Successfully trained model (DB writes deferred)",
inventory_product_id=inventory_product_id)
# Report completion to progress tracker (emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
return inventory_product_id, result
except Exception as e:
logger.error("Failed to train model",
inventory_product_id=inventory_product_id,
error=str(e))
result = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0,
'training_time_seconds': time.time() - product_start_time,
'failed_at': datetime.now().isoformat()
}
# Report failure to progress tracker (still emits Event 3: product_completed)
await progress_tracker.mark_product_completed(inventory_product_id)
return inventory_product_id, result
async def _train_all_models_enhanced(self,
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str,
repos: Dict,
progress_tracker: ParallelProductProgressTracker,
product_categories: Dict[str, ProductCategory] = None,
session = None) -> Dict[str, Any]:
"""
Train models with throttled parallel execution and progress tracking
Args:
session: Database session to pass to training methods (prevents nested session issues)
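
        Products are trained in batches of at most MAX_CONCURRENT_TRAININGS
        (falling back to 3 when the setting is absent), and per-product results
        are merged into a single dictionary afterwards.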
"""
total_products = len(processed_data)
logger.info(f"Starting throttled parallel training for {total_products} products")
# Create training tasks for all products
# ✅ FIX: Pass session to prevent nested session issues and deadlocks
training_tasks = [
self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=product_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN,
session=session
)
for inventory_product_id, product_data in processed_data.items()
]
# Execute training tasks with throttling to prevent heartbeat blocking
# Limit concurrent operations to prevent CPU/memory exhaustion
from app.core.config import settings
max_concurrent = getattr(settings, 'MAX_CONCURRENT_TRAININGS', 3)
logger.info(f"Executing training with max {max_concurrent} concurrent operations",
total_products=total_products)
# Process tasks in batches to prevent blocking the event loop
results_list = []
for i in range(0, len(training_tasks), max_concurrent):
batch = training_tasks[i:i + max_concurrent]
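            # return_exceptions=True keeps one failed product from cancelling the rest of
            # the batch; exceptions are filtered out when results are collected below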
batch_results = await asyncio.gather(*batch, return_exceptions=True)
results_list.extend(batch_results)
# Yield control to event loop to allow heartbeat processing
# Increased from 0.01s to 0.1s (100ms) to ensure WebSocket pings, RabbitMQ heartbeats,
# and progress events can be processed during long training operations
await asyncio.sleep(0.1)
# Log progress to verify event loop is responsive
logger.debug(
"Training batch completed, yielding to event loop",
batch_num=(i // max_concurrent) + 1,
total_batches=(len(training_tasks) + max_concurrent - 1) // max_concurrent,
products_completed=len(results_list),
total_products=len(training_tasks)
)
# Log final summary
summary = progress_tracker.get_progress()
logger.info("Throttled parallel training completed",
total=summary['total_products'],
completed=summary['products_completed'])
# Convert results to dictionary
training_results = {}
for result in results_list:
if isinstance(result, Exception):
logger.error(f"Training task failed with exception: {result}")
continue
product_id, product_result = result
training_results[product_id] = product_result
logger.info(f"Throttled parallel training completed: {len(training_results)} products processed")
return training_results
async def _write_training_results_to_database(self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
repos: Dict) -> Dict[str, Any]:
"""
Write training results to database sequentially to avoid concurrent session conflicts.
This method is called AFTER all parallel training is complete.
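        Results whose database write fails keep their training output and gain a
        'db_write_error' key instead of being dropped.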
"""
logger.info("Writing training results to database sequentially",
total_products=len(training_results))
updated_results = {}
for product_id, result in training_results.items():
try:
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record
model_record = await self._create_model_record(
repos, tenant_id, product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
logger.info("Database records created successfully",
inventory_product_id=product_id,
model_record_id=model_record.id if model_record else None)
# Remove product_data from result to avoid serialization issues
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
except Exception as e:
logger.error("Failed to write database records for product",
inventory_product_id=product_id,
error=str(e))
# Keep the training result but mark that DB write failed
result['db_write_error'] = str(e)
if 'product_data' in result:
del result['product_data']
if 'product_category' in result:
del result['product_category']
updated_results[product_id] = result
logger.info("Database writes completed",
successful_writes=len([r for r in updated_results.values() if 'model_record_id' in r]),
total_products=len(updated_results))
return updated_results
async def _create_model_record(self,
repos: Dict,
tenant_id: str,
inventory_product_id: str,
model_info: Dict,
job_id: str,
processed_data: pd.DataFrame):
"""Create model record using repository"""
try:
# Extract training period from the processed data
training_start_date = None
training_end_date = None
data_freshness_days = None
data_coverage_days = None
if 'ds' in processed_data.columns and not processed_data.empty:
# Ensure ds column is datetime64 before extracting dates (prevents object dtype issues)
ds_datetime = pd.to_datetime(processed_data['ds'])
# Get min/max as pandas Timestamps (guaranteed to work correctly)
min_ts = ds_datetime.min()
max_ts = ds_datetime.max()
# Convert to python datetime with timezone removal
if pd.notna(min_ts):
training_start_date = pd.Timestamp(min_ts).to_pydatetime().replace(tzinfo=None)
if pd.notna(max_ts):
training_end_date = pd.Timestamp(max_ts).to_pydatetime().replace(tzinfo=None)
# Calculate data freshness metrics
if training_end_date:
data_freshness_days = (datetime.now() - training_end_date).days
# Calculate data coverage period
if training_start_date and training_end_date:
data_coverage_days = (training_end_date - training_start_date).days
# Ensure features are clean string list
try:
features_used = [str(col) for col in processed_data.columns]
except Exception:
features_used = []
# Prepare hyperparameters with data freshness metrics
hyperparameters = model_info.get("hyperparameters", {})
if data_freshness_days is not None:
hyperparameters["data_freshness_days"] = data_freshness_days
if data_coverage_days is not None:
hyperparameters["data_coverage_days"] = data_coverage_days
model_data = {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"job_id": job_id,
"model_type": "enhanced_prophet",
"model_path": model_info.get("model_path"),
"metadata_path": model_info.get("metadata_path"),
"mape": float(model_info.get("training_metrics", {}).get("mape", 0)) if model_info.get("training_metrics", {}).get("mape") is not None else 0,
"mae": float(model_info.get("training_metrics", {}).get("mae", 0)) if model_info.get("training_metrics", {}).get("mae") is not None else 0,
"rmse": float(model_info.get("training_metrics", {}).get("rmse", 0)) if model_info.get("training_metrics", {}).get("rmse") is not None else 0,
"r2_score": float(model_info.get("training_metrics", {}).get("r2", 0)) if model_info.get("training_metrics", {}).get("r2") is not None else 0,
"training_samples": int(len(processed_data)),
"hyperparameters": self._serialize_scalers(hyperparameters),
"features_used": [str(f) for f in features_used] if features_used else [],
"normalization_params": self._serialize_scalers(self.enhanced_data_processor.get_scalers()) or {}, # Include scalers for prediction consistency
"product_category": model_info.get("product_category", "unknown"), # Store product category
"is_active": True,
"is_production": True,
"data_quality_score": float(model_info.get("data_quality_score", 100.0)) if model_info.get("data_quality_score") is not None else 100.0,
"training_start_date": training_start_date,
"training_end_date": training_end_date
}
model_record = await repos['model'].create_model(model_data)
logger.info("Created enhanced model record",
inventory_product_id=inventory_product_id,
model_id=model_record.id,
data_freshness_days=data_freshness_days,
data_coverage_days=data_coverage_days)
# Create artifacts for model files
if model_info.get("model_path"):
await repos['artifact'].create_artifact({
"model_id": str(model_record.id),
"tenant_id": tenant_id,
"artifact_type": "enhanced_model_file",
"file_path": model_info["model_path"],
"storage_location": "local"
})
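# Only the serialized model file gets a dedicated artifact record; the metadata path is already
# stored on the model record itself (metadata_path above).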
return model_record
except Exception as e:
logger.error("Failed to create enhanced model record",
inventory_product_id=inventory_product_id,
error=str(e))
return None
async def _create_performance_metrics(self,
repos: Dict,
model_id: str,
tenant_id: str,
inventory_product_id: str,
metrics: Dict):
"""Create performance metrics record using repository"""
try:
metric_data = {
"model_id": str(model_id),
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"mae": float(metrics.get("mae")) if metrics.get("mae") is not None else None,
"mse": float(metrics.get("mse")) if metrics.get("mse") is not None else None,
"rmse": float(metrics.get("rmse")) if metrics.get("rmse") is not None else None,
"mape": float(metrics.get("mape")) if metrics.get("mape") is not None else None,
"r2_score": float(metrics.get("r2")) if metrics.get("r2") is not None else None,
"accuracy_percentage": float(100 - metrics.get("mape", 0)) if metrics.get("mape") is not None else None,
"evaluation_samples": int(metrics.get("data_points", 0)) if metrics.get("data_points") is not None else 0
}
await repos['performance'].create_performance_metric(metric_data)
logger.info("Created enhanced performance metrics",
inventory_product_id=inventory_product_id,
model_id=model_id)
except Exception as e:
logger.error("Failed to create enhanced performance metrics",
inventory_product_id=inventory_product_id,
error=str(e))
async def _calculate_enhanced_training_summary(self,
training_results: Dict[str, Any],
repos: Dict,
tenant_id: str) -> Dict[str, Any]:
"""Calculate enhanced summary statistics with repository data"""
total_products = len(training_results)
successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])
# Calculate average training metrics for successful models
successful_results = [r for r in training_results.values() if r.get('status') == 'success']
avg_metrics = {}
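# Averages are only computed when every successful product reported a non-empty metrics dict;
# otherwise avg_metrics stays empty (see the all() guard below).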
if successful_results:
metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
if metrics_list and all(metrics_list):
avg_metrics = {
'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
'avg_training_time': round(np.mean([r.get('training_time_seconds', 0) for r in successful_results]), 2)
}
# Calculate data quality insights
data_points_list = [r.get('data_points', 0) for r in training_results.values()]
# Get database statistics
try:
# Get tenant model count from repository
tenant_models = await repos['model'].get_models_by_tenant(tenant_id)
models_created = [r.get('model_record_id') for r in successful_results if r.get('model_record_id')]
db_stats = {
'total_tenant_models': len(tenant_models),
'models_created_this_job': len(models_created),
'total_db_records': len(models_created),
'performance_metrics_created': len(models_created), # One per model
'artifacts_created': len([r for r in successful_results if r.get('model_info', {}).get('model_path')])
}
except Exception as e:
logger.warning("Failed to get database statistics", error=str(e))
db_stats = {
'total_tenant_models': 0,
'models_created_this_job': 0,
'total_db_records': 0,
'performance_metrics_created': 0,
'artifacts_created': 0
}
# Build models_created with proper model result structure
models_created = {}
for product, result in training_results.items():
if result.get('status') == 'success' and result.get('model_info'):
model_info = result['model_info']
models_created[product] = {
'status': 'completed',
'model_path': model_info.get('model_path'),
'metadata_path': model_info.get('metadata_path'),
'metrics': model_info.get('training_metrics', {}),
'hyperparameters': model_info.get('hyperparameters', {}),
'features_used': model_info.get('features_used', []),
'data_points': result.get('data_points', 0),
'data_quality_score': model_info.get('data_quality_score', 100.0),
'model_record_id': result.get('model_record_id')
}
enhanced_summary = {
'total_products': total_products,
'successful_products': successful_products,
'failed_products': failed_products,
'skipped_products': skipped_products,
'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
'enhanced_average_metrics': avg_metrics,
'enhanced_data_summary': {
'total_data_points': sum(data_points_list),
'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
'min_data_points': min(data_points_list) if data_points_list else 0,
'max_data_points': max(data_points_list) if data_points_list else 0
},
'database_statistics': db_stats,
'models_created': models_created
}
# Also expose the database statistics at the top level of the summary (in addition to the nested 'database_statistics' key)
enhanced_summary.update(db_stats)
return enhanced_summary
async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
"""Validate input sales data with enhanced error reporting"""
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
# Handle quantity column mapping
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
logger.info("Mapped quantity column",
from_column='quantity_sold',
to_column='quantity')
required_columns = ['date', 'inventory_product_id', 'quantity']
missing_columns = [col for col in required_columns if col not in sales_df.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid dates
try:
sales_df['date'] = pd.to_datetime(sales_df['date'])
except Exception as e:
raise ValueError(f"Invalid date format in sales data: {e}") from e
# Check for valid quantities
if not pd.api.types.is_numeric_dtype(sales_df['quantity']):
try:
sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
except Exception:
raise ValueError("Quantity column must be numeric")
async def _categorize_all_products(
self,
sales_df: pd.DataFrame,
processed_data: Dict[str, pd.DataFrame]
) -> Dict[str, ProductCategory]:
"""
Categorize all products for category-specific forecasting.
Args:
sales_df: Raw sales data with product names
processed_data: Processed data by product ID
Returns:
Dict mapping inventory_product_id to ProductCategory
"""
product_categories = {}
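# For each trained product: look up its display name in the raw sales rows, pass the name plus
# (optional) sales history to the categorizer, and fall back to ProductCategory.UNKNOWN if
# categorization fails.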
for inventory_product_id in processed_data.keys():
try:
# Get product name from sales data (if available)
product_sales = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
# Extract product name (try multiple possible column names)
product_name = "unknown"
for name_col in ['product_name', 'name', 'item_name']:
if name_col in product_sales.columns and product_sales[name_col].notna().any():
product_name = product_sales[name_col].dropna().iloc[0]
break
# Prepare sales data for pattern analysis
sales_for_analysis = product_sales[['date', 'quantity']].copy() if 'date' in product_sales.columns else None
# Categorize product
category = self.product_categorizer.categorize_product(
product_name=str(product_name),
product_id=inventory_product_id,
sales_data=sales_for_analysis
)
product_categories[inventory_product_id] = category
logger.debug("Product categorized",
inventory_product_id=inventory_product_id,
product_name=product_name,
category=category.value)
except Exception as e:
logger.warning("Failed to categorize product", inventory_product_id=inventory_product_id, error=str(e))
product_categories[inventory_product_id] = ProductCategory.UNKNOWN
return product_categories
async def evaluate_model_performance_enhanced(self,
tenant_id: str,
inventory_product_id: str,
model_path: str,
test_dataset: TrainingDataSet) -> Dict[str, Any]:
"""
Enhanced model evaluation with repository integration.
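
Flow: rebuild the product's test features with the enhanced data processor, generate a
forecast from the stored Prophet model, compute MAE/RMSE/R2/MAPE against the held-out
actuals, and persist the metrics against the most recent model record for the product.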
"""
try:
logger.info("Enhanced model evaluation starting",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id)
# Get database session and repositories
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Convert test data to DataFrames
test_sales_df = pd.DataFrame(test_dataset.sales_data)
test_weather_df = pd.DataFrame(test_dataset.weather_data)
test_traffic_df = pd.DataFrame(test_dataset.traffic_data)
# Filter for specific product
product_test_sales = test_sales_df[test_sales_df['inventory_product_id'] == inventory_product_id].copy()
if product_test_sales.empty:
raise ValueError(f"No test data found for product: {inventory_product_id}")
# Process test data using enhanced processor
processed_test_data = await self.enhanced_data_processor.prepare_training_data(
sales_data=product_test_sales,
weather_data=test_weather_df,
traffic_data=test_traffic_df,
inventory_product_id=inventory_product_id,
poi_features=test_dataset.poi_features, # POI features for location-based forecasting
tenant_id=tenant_id
)
# Create future dataframe for prediction
future_dates = processed_test_data[['ds']].copy()
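# Prophet requires every regressor used at training time to be present in the future frame,
# so copy all non-target columns from the processed test data alongside 'ds'.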
# Add regressor columns
regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
for col in regressor_columns:
future_dates[col] = processed_test_data[col]
# Generate predictions
forecast = await self.prophet_manager.generate_forecast(
model_path=model_path,
future_dates=future_dates,
regressor_columns=regressor_columns
)
# Calculate performance metrics
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
y_true = processed_test_data['y'].values
y_pred = forecast['yhat'].values
# Ensure arrays are the same length
min_len = min(len(y_true), len(y_pred))
y_true = y_true[:min_len]
y_pred = y_pred[:min_len]
metrics = {
"mae": float(mean_absolute_error(y_true, y_pred)),
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
"r2_score": float(r2_score(y_true, y_pred))
}
# Calculate MAPE safely
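# Rows with near-zero actuals (<= 0.1) are excluded to avoid dividing by ~0, and the result is
# capped at 200% so a few tiny denominators cannot blow up the metric; if no rows qualify,
# MAPE defaults to 100%.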
non_zero_mask = y_true > 0.1
if np.sum(non_zero_mask) > 0:
mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
metrics["mape"] = float(min(mape, 200)) # Cap at 200%
else:
metrics["mape"] = 100.0
# Store evaluation metrics in repository
model_records = await repos['model'].get_models_by_product(tenant_id, inventory_product_id)
if model_records:
latest_model = max(model_records, key=lambda x: x.created_at)
await self._create_performance_metrics(
repos, latest_model.id, tenant_id, inventory_product_id, metrics
)
result = {
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"enhanced_evaluation_metrics": metrics,
"test_samples": len(processed_test_data),
"prediction_samples": len(forecast),
"test_period": {
"start": test_dataset.date_range.start.isoformat(),
"end": test_dataset.date_range.end.isoformat()
},
"evaluated_at": datetime.now().isoformat(),
"repository_integration": {
"metrics_stored": True,
"model_record_found": len(model_records) > 0 if model_records else False
}
}
return result
except Exception as e:
logger.error("Enhanced model evaluation failed", error=str(e))
raise