# Source file: bakery-ia/services/training/app/ml/trainer.py
"""
2025-08-08 09:08:41 +02:00
Enhanced ML Trainer with Repository Pattern
Main ML pipeline coordinator using repository pattern for data access and dependency injection
"""
2025-07-28 19:28:39 +02:00
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
2025-07-28 19:28:39 +02:00
from datetime import datetime
2025-08-08 09:08:41 +02:00
import structlog
2025-07-19 16:59:37 +02:00
import uuid
2025-08-04 18:21:42 +02:00
import time
2025-08-08 09:08:41 +02:00
from app.ml.data_processor import EnhancedBakeryDataProcessor
2025-07-28 19:28:39 +02:00
from app.ml.prophet_manager import BakeryProphetManager
from app.services.training_orchestrator import TrainingDataSet
from app.core.config import settings
2025-08-08 09:08:41 +02:00
from shared.database.base import create_database_manager
from shared.database.transactions import transactional
from shared.database.unit_of_work import UnitOfWork
from shared.database.exceptions import DatabaseError
from app.repositories import (
ModelRepository,
TrainingLogRepository,
PerformanceRepository,
ArtifactRepository
)
2025-07-28 19:28:39 +02:00
2025-08-04 18:58:12 +02:00
from app.services.messaging import TrainingStatusPublisher
2025-08-08 09:08:41 +02:00
logger = structlog.get_logger()
2025-08-08 09:08:41 +02:00
class EnhancedBakeryMLTrainer:
    """
    Enhanced ML trainer using repository pattern for data access and comprehensive tracking.
    Orchestrates the complete ML training pipeline with proper database abstraction.
    """

    def __init__(self, database_manager=None):
        """
        Args:
            database_manager: Optional pre-built database manager. When omitted,
                one is created from the service settings.
        """
        self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
        self.enhanced_data_processor = EnhancedBakeryDataProcessor(self.database_manager)
        self.prophet_manager = BakeryProphetManager(database_manager=self.database_manager)
        # Bug fix: the status publisher is only assigned inside
        # train_tenant_models(), yet _train_all_models_enhanced() guards on
        # `if self.status_publisher:`. Without this default, invoking helper
        # methods before train_tenant_models() raised AttributeError instead
        # of skipping publication.
        self.status_publisher = None
async def _get_repositories(self, session):
"""Initialize repositories with session"""
return {
'model': ModelRepository(session),
'training_log': TrainingLogRepository(session),
'performance': PerformanceRepository(session),
'artifact': ArtifactRepository(session)
}
2025-07-19 16:59:37 +02:00
    async def train_tenant_models(self,
                                  tenant_id: str,
                                  training_dataset: TrainingDataSet,
                                  job_id: Optional[str] = None,
                                  session=None) -> Dict[str, Any]:
        """
        Train models for all products using repository pattern with enhanced tracking.

        Args:
            tenant_id: Tenant identifier
            training_dataset: Prepared training dataset with aligned dates
            job_id: Training job identifier; auto-generated when omitted
            session: Unused here — a fresh session is opened internally.
                NOTE(review): kept for interface compatibility; confirm callers.

        Returns:
            Dictionary with training results for each product, plus an enhanced
            summary, date-range info and repository bookkeeping metadata.

        Raises:
            Exception: any pipeline failure is logged and re-raised.
        """
        if not job_id:
            job_id = f"enhanced_ml_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info("Starting enhanced ML training pipeline",
                    job_id=job_id,
                    tenant_id=tenant_id)

        # Publisher broadcasts job progress; helper methods also use it.
        self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)

        try:
            # Get database session and repositories
            async with self.database_manager.get_session() as db_session:
                repos = await self._get_repositories(db_session)

                # Convert sales data to DataFrame
                sales_df = pd.DataFrame(training_dataset.sales_data)
                weather_df = pd.DataFrame(training_dataset.weather_data)
                traffic_df = pd.DataFrame(training_dataset.traffic_data)

                # Validate input data (may normalize columns of sales_df in place)
                await self._validate_input_data(sales_df, tenant_id)

                # Get unique products from the sales data
                products = sales_df['product_name'].unique().tolist()

                logger.info("Training enhanced models",
                            products_count=len(products),
                            products=products)

                self.status_publisher.products_total = len(products)

                # Mark the data-processing phase as started (5% progress)
                await repos['training_log'].update_log_progress(
                    job_id, 5, "data_processing", "running"
                )

                # Process data for each product using enhanced processor
                logger.info("Processing data using enhanced processor")
                processed_data = await self._process_all_products_enhanced(
                    sales_df, weather_df, traffic_df, products, tenant_id, job_id
                )

                await self.status_publisher.progress_update(
                    progress=20,
                    step="feature_engineering",
                    step_details="Enhanced processing with repository tracking"
                )

                # Train models for each processed product
                logger.info("Training models with repository integration")
                training_results = await self._train_all_models_enhanced(
                    tenant_id, processed_data, job_id, repos
                )

                # Calculate overall training summary with enhanced metrics
                summary = await self._calculate_enhanced_training_summary(
                    training_results, repos, tenant_id
                )

                await self.status_publisher.progress_update(
                    progress=90,
                    step="model_validation",
                    step_details="Enhanced validation with repository tracking"
                )

                # Create comprehensive result with repository data
                result = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "status": "completed",
                    "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
                    "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
                    "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
                    "total_products": len(products),
                    "training_results": training_results,
                    "enhanced_summary": summary,
                    "models_trained": summary.get('models_created', {}),
                    "data_info": {
                        "date_range": {
                            "start": training_dataset.date_range.start.isoformat(),
                            "end": training_dataset.date_range.end.isoformat(),
                            "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
                        },
                        "data_sources": [source.value for source in training_dataset.date_range.available_sources],
                        "constraints_applied": training_dataset.date_range.constraints
                    },
                    # These flat keys are produced by
                    # _calculate_enhanced_training_summary's db_stats.
                    "repository_metadata": {
                        "total_records_created": summary.get('total_db_records', 0),
                        "performance_metrics_stored": summary.get('performance_metrics_created', 0),
                        "artifacts_created": summary.get('artifacts_created', 0)
                    },
                    "completed_at": datetime.now().isoformat()
                }

                logger.info("Enhanced ML training pipeline completed successfully",
                            job_id=job_id,
                            models_created=len([r for r in training_results.values() if r.get('status') == 'success']))

                return result

        except Exception as e:
            logger.error("Enhanced ML training pipeline failed",
                         job_id=job_id,
                         error=str(e))
            raise
async def _process_all_products_enhanced(self,
sales_df: pd.DataFrame,
weather_df: pd.DataFrame,
traffic_df: pd.DataFrame,
products: List[str],
tenant_id: str,
job_id: str) -> Dict[str, pd.DataFrame]:
"""Process data for all products using enhanced processor with repository tracking"""
2025-07-19 16:59:37 +02:00
processed_data = {}
2025-07-19 16:59:37 +02:00
for product_name in products:
try:
2025-08-08 09:08:41 +02:00
logger.info("Processing data for product using enhanced processor",
product_name=product_name)
2025-07-19 16:59:37 +02:00
# Filter sales data for this product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
2025-07-28 19:28:39 +02:00
if product_sales.empty:
2025-08-08 09:08:41 +02:00
logger.warning("No sales data found for product",
product_name=product_name)
2025-07-28 19:28:39 +02:00
continue
2025-08-08 09:08:41 +02:00
# Use enhanced data processor with repository tracking
processed_product_data = await self.enhanced_data_processor.prepare_training_data(
2025-07-19 16:59:37 +02:00
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
2025-08-08 09:08:41 +02:00
product_name=product_name,
tenant_id=tenant_id,
job_id=job_id
2025-07-19 16:59:37 +02:00
)
processed_data[product_name] = processed_product_data
2025-08-08 09:08:41 +02:00
logger.info("Enhanced processing completed",
product_name=product_name,
data_points=len(processed_product_data))
2025-07-19 16:59:37 +02:00
except Exception as e:
2025-08-08 09:08:41 +02:00
logger.error("Failed to process data using enhanced processor",
product_name=product_name,
error=str(e))
2025-07-19 16:59:37 +02:00
continue
2025-07-19 16:59:37 +02:00
return processed_data
2025-08-08 09:08:41 +02:00
async def _train_all_models_enhanced(self,
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str,
repos: Dict) -> Dict[str, Any]:
"""Train models with enhanced repository integration"""
2025-07-19 16:59:37 +02:00
training_results = {}
2025-08-04 18:58:12 +02:00
i = 0
2025-08-04 18:21:42 +02:00
total_products = len(processed_data)
base_progress = 45
2025-08-04 18:58:12 +02:00
max_progress = 85
2025-08-04 18:21:42 +02:00
2025-07-19 16:59:37 +02:00
for product_name, product_data in processed_data.items():
2025-08-04 18:21:42 +02:00
product_start_time = time.time()
try:
2025-08-08 09:08:41 +02:00
logger.info("Training enhanced model",
product_name=product_name)
2025-07-19 16:59:37 +02:00
# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
training_results[product_name] = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
2025-07-28 19:28:39 +02:00
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
2025-07-19 16:59:37 +02:00
}
2025-08-08 09:08:41 +02:00
logger.warning("Skipping product due to insufficient data",
product_name=product_name,
data_points=len(product_data),
min_required=settings.MIN_TRAINING_DATA_DAYS)
2025-07-19 16:59:37 +02:00
continue
2025-07-28 19:28:39 +02:00
# Train the model using Prophet manager
2025-07-19 16:59:37 +02:00
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
product_name=product_name,
df=product_data,
job_id=job_id
)
2025-08-08 09:08:41 +02:00
# Store model record using repository
model_record = await self._create_model_record(
repos, tenant_id, product_name, model_info, job_id, product_data
)
# Create performance metrics record
if model_info.get('training_metrics'):
await self._create_performance_metrics(
repos, model_record.id if model_record else None,
tenant_id, product_name, model_info['training_metrics']
)
2025-07-19 16:59:37 +02:00
training_results[product_name] = {
'status': 'success',
'model_info': model_info,
2025-08-08 09:08:41 +02:00
'model_record_id': model_record.id if model_record else None,
2025-07-19 16:59:37 +02:00
'data_points': len(product_data),
2025-08-08 09:08:41 +02:00
'training_time_seconds': time.time() - product_start_time,
2025-07-19 16:59:37 +02:00
'trained_at': datetime.now().isoformat()
}
2025-08-08 09:08:41 +02:00
logger.info("Successfully trained enhanced model",
product_name=product_name,
model_record_id=model_record.id if model_record else None)
2025-07-19 16:59:37 +02:00
2025-08-04 18:58:12 +02:00
completed_products = i + 1
2025-08-08 09:08:41 +02:00
i += 1
2025-08-04 18:58:12 +02:00
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
if self.status_publisher:
self.status_publisher.products_completed = completed_products
2025-08-04 19:17:31 +02:00
await self.status_publisher.progress_update(
2025-08-04 18:58:12 +02:00
progress=progress,
step="model_training",
current_product=product_name,
2025-08-08 09:08:41 +02:00
step_details=f"Enhanced training completed for {product_name}"
2025-08-04 18:58:12 +02:00
)
except Exception as e:
2025-08-08 09:08:41 +02:00
logger.error("Failed to train enhanced model",
product_name=product_name,
error=str(e))
2025-07-19 16:59:37 +02:00
training_results[product_name] = {
'status': 'error',
'error_message': str(e),
2025-07-28 19:28:39 +02:00
'data_points': len(product_data) if product_data is not None else 0,
2025-08-08 09:08:41 +02:00
'training_time_seconds': time.time() - product_start_time,
2025-07-28 19:28:39 +02:00
'failed_at': datetime.now().isoformat()
2025-07-19 16:59:37 +02:00
}
2025-08-04 18:58:12 +02:00
completed_products = i + 1
2025-08-08 09:08:41 +02:00
i += 1
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
2025-08-04 18:58:12 +02:00
if self.status_publisher:
self.status_publisher.products_completed = completed_products
await self.status_publisher.progress_update(
progress=progress,
step="model_training",
current_product=product_name,
2025-08-08 09:08:41 +02:00
step_details=f"Enhanced training failed for {product_name}: {str(e)}"
2025-08-04 18:58:12 +02:00
)
2025-08-04 18:21:42 +02:00
2025-07-19 16:59:37 +02:00
return training_results
2025-08-08 09:08:41 +02:00
async def _create_model_record(self,
repos: Dict,
tenant_id: str,
product_name: str,
model_info: Dict,
job_id: str,
processed_data: pd.DataFrame):
"""Create model record using repository"""
try:
model_data = {
"tenant_id": tenant_id,
"product_name": product_name,
"job_id": job_id,
"model_type": "enhanced_prophet",
"model_path": model_info.get("model_path"),
"metadata_path": model_info.get("metadata_path"),
"mape": model_info.get("training_metrics", {}).get("mape"),
"mae": model_info.get("training_metrics", {}).get("mae"),
"rmse": model_info.get("training_metrics", {}).get("rmse"),
"r2_score": model_info.get("training_metrics", {}).get("r2"),
"training_samples": len(processed_data),
"hyperparameters": model_info.get("hyperparameters"),
"features_used": list(processed_data.columns),
2025-08-12 18:17:30 +02:00
"normalization_params": self.enhanced_data_processor.get_scalers(), # Include scalers for prediction consistency
2025-08-08 09:08:41 +02:00
"is_active": True,
"is_production": True,
"data_quality_score": model_info.get("data_quality_score", 100.0)
}
model_record = await repos['model'].create_model(model_data)
logger.info("Created enhanced model record",
product_name=product_name,
model_id=model_record.id)
# Create artifacts for model files
if model_info.get("model_path"):
await repos['artifact'].create_artifact({
"model_id": str(model_record.id),
"tenant_id": tenant_id,
"artifact_type": "enhanced_model_file",
"file_path": model_info["model_path"],
"storage_location": "local"
})
return model_record
except Exception as e:
logger.error("Failed to create enhanced model record",
product_name=product_name,
error=str(e))
return None
async def _create_performance_metrics(self,
repos: Dict,
model_id: str,
tenant_id: str,
product_name: str,
metrics: Dict):
"""Create performance metrics record using repository"""
try:
metric_data = {
"model_id": str(model_id),
"tenant_id": tenant_id,
"product_name": product_name,
"mae": metrics.get("mae"),
"mse": metrics.get("mse"),
"rmse": metrics.get("rmse"),
"mape": metrics.get("mape"),
"r2_score": metrics.get("r2"),
"accuracy_percentage": 100 - metrics.get("mape", 0) if metrics.get("mape") else None,
"evaluation_samples": metrics.get("data_points", 0)
}
await repos['performance'].create_performance_metric(metric_data)
logger.info("Created enhanced performance metrics",
product_name=product_name,
model_id=model_id)
except Exception as e:
logger.error("Failed to create enhanced performance metrics",
product_name=product_name,
error=str(e))
    async def _calculate_enhanced_training_summary(self,
                                                   training_results: Dict[str, Any],
                                                   repos: Dict,
                                                   tenant_id: str) -> Dict[str, Any]:
        """Calculate enhanced summary statistics with repository data.

        Aggregates per-product training results into counts, averaged metrics,
        data-volume stats, database statistics and a per-product
        'models_created' mapping consumed by train_tenant_models().
        """
        total_products = len(training_results)
        successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
        failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
        skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])

        # Calculate average training metrics for successful models
        successful_results = [r for r in training_results.values() if r.get('status') == 'success']
        avg_metrics = {}
        if successful_results:
            metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
            # all(metrics_list): only average when every model reported
            # a non-empty metrics dict.
            if metrics_list and all(metrics_list):
                avg_metrics = {
                    'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
                    'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
                    'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
                    'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
                    'avg_training_time': round(np.mean([r.get('training_time_seconds', 0) for r in successful_results]), 2)
                }

        # Calculate data quality insights
        data_points_list = [r.get('data_points', 0) for r in training_results.values()]

        # Get database statistics
        try:
            # Get tenant model count from repository
            tenant_models = await repos['model'].get_models_by_tenant(tenant_id)
            # NOTE: `models_created` here is a temporary list of record ids;
            # the same name is rebound below to a dict keyed by product.
            models_created = [r.get('model_record_id') for r in successful_results if r.get('model_record_id')]
            db_stats = {
                'total_tenant_models': len(tenant_models),
                'models_created_this_job': len(models_created),
                'total_db_records': len(models_created),
                'performance_metrics_created': len(models_created),  # One per model
                'artifacts_created': len([r for r in successful_results if r.get('model_info', {}).get('model_path')])
            }
        except Exception as e:
            # Best-effort: fall back to zeroed statistics rather than failing
            # the whole summary.
            logger.warning("Failed to get database statistics", error=str(e))
            db_stats = {
                'total_tenant_models': 0,
                'models_created_this_job': 0,
                'total_db_records': 0,
                'performance_metrics_created': 0,
                'artifacts_created': 0
            }

        # Build models_created with proper model result structure
        models_created = {}
        for product, result in training_results.items():
            if result.get('status') == 'success' and result.get('model_info'):
                model_info = result['model_info']
                models_created[product] = {
                    'status': 'completed',
                    'model_path': model_info.get('model_path'),
                    'metadata_path': model_info.get('metadata_path'),
                    'metrics': model_info.get('training_metrics', {}),
                    'hyperparameters': model_info.get('hyperparameters', {}),
                    'features_used': model_info.get('features_used', []),
                    'data_points': result.get('data_points', 0),
                    'data_quality_score': model_info.get('data_quality_score', 100.0),
                    'model_record_id': result.get('model_record_id')
                }

        enhanced_summary = {
            'total_products': total_products,
            'successful_products': successful_products,
            'failed_products': failed_products,
            'skipped_products': skipped_products,
            'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
            'enhanced_average_metrics': avg_metrics,
            'enhanced_data_summary': {
                'total_data_points': sum(data_points_list),
                'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
                'min_data_points': min(data_points_list) if data_points_list else 0,
                'max_data_points': max(data_points_list) if data_points_list else 0
            },
            'database_statistics': db_stats,
            'models_created': models_created
        }

        # Add database statistics to the summary.
        # This duplicates db_stats keys at the top level in addition to
        # 'database_statistics' — train_tenant_models() reads the flat keys
        # (total_db_records, performance_metrics_created, artifacts_created).
        enhanced_summary.update(db_stats)

        return enhanced_summary
async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
"""Validate input sales data with enhanced error reporting"""
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
# Handle quantity column mapping
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
logger.info("Mapped quantity column",
from_column='quantity_sold',
to_column='quantity')
required_columns = ['date', 'product_name', 'quantity']
missing_columns = [col for col in required_columns if col not in sales_df.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid dates
try:
sales_df['date'] = pd.to_datetime(sales_df['date'])
except Exception:
raise ValueError("Invalid date format in sales data")
# Check for valid quantities
if not sales_df['quantity'].dtype in ['int64', 'float64']:
try:
sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
except Exception:
raise ValueError("Quantity column must be numeric")
    async def evaluate_model_performance_enhanced(self,
                                                  tenant_id: str,
                                                  product_name: str,
                                                  model_path: str,
                                                  test_dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Enhanced model evaluation with repository integration.

        Runs the stored model at model_path against a held-out test dataset,
        computes MAE/RMSE/R2/MAPE, and stores the metrics against the latest
        model record for (tenant_id, product_name) when one exists.

        Raises:
            ValueError: when the test dataset has no rows for product_name.
            Exception: any other failure is logged and re-raised.
        """
        try:
            logger.info("Enhanced model evaluation starting",
                        tenant_id=tenant_id,
                        product_name=product_name)

            # Get database session and repositories
            async with self.database_manager.get_session() as db_session:
                repos = await self._get_repositories(db_session)

                # Convert test data to DataFrames
                test_sales_df = pd.DataFrame(test_dataset.sales_data)
                test_weather_df = pd.DataFrame(test_dataset.weather_data)
                test_traffic_df = pd.DataFrame(test_dataset.traffic_data)

                # Filter for specific product
                product_test_sales = test_sales_df[test_sales_df['product_name'] == product_name].copy()
                if product_test_sales.empty:
                    raise ValueError(f"No test data found for product: {product_name}")

                # Process test data using enhanced processor (same feature
                # pipeline as training, so regressors line up with the model)
                processed_test_data = await self.enhanced_data_processor.prepare_training_data(
                    sales_data=product_test_sales,
                    weather_data=test_weather_df,
                    traffic_data=test_traffic_df,
                    product_name=product_name,
                    tenant_id=tenant_id
                )

                # Create future dataframe for prediction
                future_dates = processed_test_data[['ds']].copy()

                # Add regressor columns (everything except Prophet's ds/y)
                regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
                for col in regressor_columns:
                    future_dates[col] = processed_test_data[col]

                # Generate predictions
                forecast = await self.prophet_manager.generate_forecast(
                    model_path=model_path,
                    future_dates=future_dates,
                    regressor_columns=regressor_columns
                )

                # Calculate performance metrics
                from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

                y_true = processed_test_data['y'].values
                y_pred = forecast['yhat'].values

                # Ensure arrays are the same length before scoring
                min_len = min(len(y_true), len(y_pred))
                y_true = y_true[:min_len]
                y_pred = y_pred[:min_len]

                metrics = {
                    "mae": float(mean_absolute_error(y_true, y_pred)),
                    "rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
                    "r2_score": float(r2_score(y_true, y_pred))
                }

                # Calculate MAPE safely: only over actuals > 0.1 to avoid
                # division blow-ups near zero
                non_zero_mask = y_true > 0.1
                if np.sum(non_zero_mask) > 0:
                    mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
                    metrics["mape"] = float(min(mape, 200))  # Cap at 200%
                else:
                    metrics["mape"] = 100.0

                # Store evaluation metrics in repository against the most
                # recently created model record, when one exists
                model_records = await repos['model'].get_models_by_product(tenant_id, product_name)
                if model_records:
                    latest_model = max(model_records, key=lambda x: x.created_at)
                    await self._create_performance_metrics(
                        repos, latest_model.id, tenant_id, product_name, metrics
                    )

                result = {
                    "tenant_id": tenant_id,
                    "product_name": product_name,
                    "enhanced_evaluation_metrics": metrics,
                    "test_samples": len(processed_test_data),
                    "prediction_samples": len(forecast),
                    "test_period": {
                        "start": test_dataset.date_range.start.isoformat(),
                        "end": test_dataset.date_range.end.isoformat()
                    },
                    "evaluated_at": datetime.now().isoformat(),
                    # NOTE(review): metrics_stored is reported True even when
                    # no model record was found to attach metrics to — confirm
                    # consumers tolerate this.
                    "repository_integration": {
                        "metrics_stored": True,
                        "model_record_found": len(model_records) > 0 if model_records else False
                    }
                }

                return result

        except Exception as e:
            logger.error("Enhanced model evaluation failed", error=str(e))
            raise
# Legacy compatibility alias — presumably older callers import
# BakeryMLTrainer directly; keep it pointing at the enhanced class.
BakeryMLTrainer = EnhancedBakeryMLTrainer