# services/training/app/ml/trainer.py
"""
ML Trainer - Main ML pipeline coordinator
Receives prepared data and orchestrates the complete ML training process
"""
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
from datetime import datetime
import logging
import uuid
import time
from app.ml.data_processor import BakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
from app.services.training_orchestrator import TrainingDataSet
from app.core.config import settings
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.messaging import TrainingStatusPublisher
logger = logging.getLogger(__name__)
class BakeryMLTrainer:
"""
Main ML trainer that orchestrates the complete ML training pipeline.
Receives prepared TrainingDataSet and coordinates data processing and model training.
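
    Example (illustrative sketch; assumes an active AsyncSession `session` and a
    TrainingDataSet `dataset` already prepared by the training orchestrator):

        trainer = BakeryMLTrainer(db_session=session)
        results = await trainer.train_tenant_models(
            tenant_id="tenant-123",
            training_dataset=dataset,
        )
        print(results["summary"]["success_rate"])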
"""
    def __init__(self, db_session: Optional[AsyncSession] = None):
        self.data_processor = BakeryDataProcessor()
        self.prophet_manager = BakeryProphetManager(db_session=db_session)
        # Set by train_tenant_models; stays None for single-product training runs
        self.status_publisher: Optional[TrainingStatusPublisher] = None
async def train_tenant_models(self,
tenant_id: str,
training_dataset: TrainingDataSet,
job_id: Optional[str] = None) -> Dict[str, Any]:
"""
Train models for all products using prepared training dataset.
Args:
tenant_id: Tenant identifier
training_dataset: Prepared training dataset with aligned dates
job_id: Training job identifier
Returns:
Dictionary with training results for each product
"""
if not job_id:
job_id = f"ml_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting ML training pipeline {job_id} for tenant {tenant_id}")
self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)
try:
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products from the sales data
products = sales_df['product_name'].unique().tolist()
logger.info(f"Training models for {len(products)} products: {products}")
self.status_publisher.products_total = len(products)
# Process data for each product
logger.info("Processing data for all products...")
processed_data = await self._process_all_products(
sales_df, weather_df, traffic_df, products
)
await self.status_publisher.progress_update(
progress=20,
step="feature_engineering",
step_details="Processing features for all products"
)
# Train models for each processed product
logger.info("Training models for all products...")
training_results = await self._train_all_models(
tenant_id, processed_data, job_id
)
# Calculate overall training summary
summary = self._calculate_training_summary(training_results)
await self.status_publisher.progress_update(
progress=90,
step="model_validation",
step_details="Validating model performance"
)
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
"total_products": len(products),
"training_results": training_results,
"summary": summary,
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"ML training pipeline {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"ML training pipeline {job_id} failed: {str(e)}")
raise
async def train_single_product_model(self,
tenant_id: str,
product_name: str,
training_dataset: TrainingDataSet,
job_id: Optional[str] = None) -> Dict[str, Any]:
"""
Train model for a single product using prepared training dataset.
Args:
tenant_id: Tenant identifier
product_name: Product name
training_dataset: Prepared training dataset
job_id: Training job identifier
Returns:
Training result for the product
"""
if not job_id:
job_id = f"single_ml_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting single product ML training {job_id} for {product_name}")
try:
# Convert training data to DataFrames
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Filter sales data for the specific product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
# Validate product data
if product_sales.empty:
raise ValueError(f"No sales data found for product: {product_name}")
# Process data for this specific product
processed_data = await self.data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
product_name=product_name
)
# Train the model
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
product_name=product_name,
df=processed_data,
job_id=job_id
)
result = {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"status": "success",
"model_info": model_info,
"data_points": len(processed_data),
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Single product ML training {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"Single product ML training {job_id} failed: {str(e)}")
raise
async def evaluate_model_performance(self,
tenant_id: str,
product_name: str,
model_path: str,
test_dataset: TrainingDataSet) -> Dict[str, Any]:
"""
Evaluate model performance using test dataset.
Args:
tenant_id: Tenant identifier
product_name: Product name
model_path: Path to the trained model
test_dataset: Test dataset for evaluation
Returns:
Performance metrics
"""
try:
logger.info(f"Evaluating model performance for {product_name}")
# Convert test data to DataFrames
test_sales_df = pd.DataFrame(test_dataset.sales_data)
test_weather_df = pd.DataFrame(test_dataset.weather_data)
test_traffic_df = pd.DataFrame(test_dataset.traffic_data)
# Filter for specific product
product_test_sales = test_sales_df[test_sales_df['product_name'] == product_name].copy()
if product_test_sales.empty:
raise ValueError(f"No test data found for product: {product_name}")
# Process test data
processed_test_data = await self.data_processor.prepare_training_data(
sales_data=product_test_sales,
weather_data=test_weather_df,
traffic_data=test_traffic_df,
product_name=product_name
)
# Create future dataframe for prediction
future_dates = processed_test_data[['ds']].copy()
# Add regressor columns
regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
for col in regressor_columns:
future_dates[col] = processed_test_data[col]
# Generate predictions
forecast = await self.prophet_manager.generate_forecast(
model_path=model_path,
future_dates=future_dates,
regressor_columns=regressor_columns
)
# Calculate performance metrics
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
y_true = processed_test_data['y'].values
y_pred = forecast['yhat'].values
# Ensure arrays are the same length
min_len = min(len(y_true), len(y_pred))
y_true = y_true[:min_len]
y_pred = y_pred[:min_len]
metrics = {
"mae": float(mean_absolute_error(y_true, y_pred)),
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
"r2_score": float(r2_score(y_true, y_pred))
}
# Calculate MAPE safely
non_zero_mask = y_true > 0.1
if np.sum(non_zero_mask) > 0:
mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
metrics["mape"] = float(min(mape, 200)) # Cap at 200%
else:
metrics["mape"] = 100.0
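            # Worked example: y_true=[10, 0, 5], y_pred=[8, 1, 6] -> the zero row is
            # masked out, so MAPE = mean(2/10, 1/5) * 100 = 20%.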
result = {
"tenant_id": tenant_id,
"product_name": product_name,
"evaluation_metrics": metrics,
"test_samples": len(processed_test_data),
"prediction_samples": len(forecast),
"test_period": {
"start": test_dataset.date_range.start.isoformat(),
"end": test_dataset.date_range.end.isoformat()
},
"evaluated_at": datetime.now().isoformat()
}
return result
except Exception as e:
logger.error(f"Model evaluation failed: {str(e)}")
raise
async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
"""Validate input sales data"""
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
# Handle quantity column mapping
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
logger.info("Mapped 'quantity_sold' to 'quantity' column")
required_columns = ['date', 'product_name', 'quantity']
missing_columns = [col for col in required_columns if col not in sales_df.columns]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid dates
try:
sales_df['date'] = pd.to_datetime(sales_df['date'])
except Exception:
raise ValueError("Invalid date format in sales data")
        # Check for valid quantities (note: to_numeric with errors='coerce' never raises,
        # so we validate the coerced result instead of relying on an exception)
        if not pd.api.types.is_numeric_dtype(sales_df['quantity']):
            sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
            if sales_df['quantity'].isna().all():
                raise ValueError("Quantity column must be numeric")
async def _process_all_products(self,
sales_df: pd.DataFrame,
weather_df: pd.DataFrame,
traffic_df: pd.DataFrame,
products: List[str]) -> Dict[str, pd.DataFrame]:
"""Process data for all products using the data processor"""
processed_data = {}
for product_name in products:
try:
logger.info(f"Processing data for product: {product_name}")
# Filter sales data for this product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
if product_sales.empty:
logger.warning(f"No sales data found for product: {product_name}")
continue
# Use data processor to prepare training data
processed_product_data = await self.data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
traffic_data=traffic_df,
product_name=product_name
)
processed_data[product_name] = processed_product_data
logger.info(f"Processed {len(processed_product_data)} data points for {product_name}")
except Exception as e:
logger.error(f"Failed to process data for {product_name}: {str(e)}")
# Continue with other products
continue
return processed_data
def calculate_estimated_time_remaining(self, processing_times: List[float], completed: int, total: int) -> int:
"""
Calculate estimated time remaining based on actual processing times
Args:
processing_times: List of processing times for completed items (in seconds)
completed: Number of items completed so far
total: Total number of items to process
Returns:
Estimated time remaining in minutes
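
        Example (worked through the logic below): with processing_times of
        [30, 40, 50, 60], completed=4 and total=10, the average of the last 3
        items is 50s and the overall average is 45s, so the weighted estimate
        is 0.7 * 50 + 0.3 * 45 = 48.5s per item; 6 remaining items then take
        about 291s, which is reported as 5 minutes (rounded up).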
"""
if not processing_times or completed >= total:
return 0
# Calculate average processing time
avg_time_per_item = sum(processing_times) / len(processing_times)
# Use weighted average giving more weight to recent processing times
if len(processing_times) > 3:
# Use last 3 items for more accurate recent performance
recent_times = processing_times[-3:]
recent_avg = sum(recent_times) / len(recent_times)
# Weighted average: 70% recent, 30% overall
avg_time_per_item = (recent_avg * 0.7) + (avg_time_per_item * 0.3)
# Calculate remaining items and estimated time
remaining_items = total - completed
estimated_seconds = remaining_items * avg_time_per_item
# Convert to minutes and round up
estimated_minutes = max(1, int(estimated_seconds / 60) + (1 if estimated_seconds % 60 > 0 else 0))
return estimated_minutes
async def _train_all_models(self,
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str) -> Dict[str, Any]:
"""Train models for all processed products using Prophet manager"""
training_results = {}
i = 0
total_products = len(processed_data)
base_progress = 45
max_progress = 85
for product_name, product_data in processed_data.items():
product_start_time = time.time()
try:
logger.info(f"Training model for product: {product_name}")
# Check if we have enough data
if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
training_results[product_name] = {
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
}
                    logger.warning(f"Skipping {product_name}: insufficient data ({len(product_data)} < {settings.MIN_TRAINING_DATA_DAYS})")
                    # Count skipped products toward progress so the bar still reaches max_progress
                    i = i + 1
                    continue
# Train the model using Prophet manager
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
product_name=product_name,
df=product_data,
job_id=job_id
)
training_results[product_name] = {
'status': 'success',
'model_info': model_info,
'data_points': len(product_data),
'trained_at': datetime.now().isoformat()
}
logger.info(f"Successfully trained model for {product_name}")
completed_products = i + 1
i = i + 1
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
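                # e.g., with 10 products and 4 completed: 45 + int(0.4 * 40) = 61% overall progress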
if self.status_publisher:
# Update products completed for accurate tracking
self.status_publisher.products_completed = completed_products
await self.status_publisher.progress_update(
progress=progress,
step="model_training",
current_product=product_name,
step_details=f"Completed training for {product_name}"
)
except Exception as e:
logger.error(f"Failed to train model for {product_name}: {str(e)}")
training_results[product_name] = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0,
'failed_at': datetime.now().isoformat()
}
                completed_products = i + 1
                i = i + 1
                # Recompute progress here so it is defined even if the first product fails
                progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
                if self.status_publisher:
                    self.status_publisher.products_completed = completed_products
                    await self.status_publisher.progress_update(
                        progress=progress,
                        step="model_training",
                        current_product=product_name,
                        step_details=f"Failed training for {product_name}: {str(e)}"
                    )
return training_results
def _calculate_training_summary(self, training_results: Dict[str, Any]) -> Dict[str, Any]:
"""Calculate summary statistics from training results"""
total_products = len(training_results)
successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])
# Calculate average training metrics for successful models
successful_results = [r for r in training_results.values() if r.get('status') == 'success']
avg_metrics = {}
if successful_results:
metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
if metrics_list and all(metrics_list):
avg_metrics = {
'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
'avg_improvement': round(np.mean([m.get('improvement_estimated', 0) for m in metrics_list]), 1)
}
# Calculate data quality insights
data_points_list = [r.get('data_points', 0) for r in training_results.values()]
return {
'total_products': total_products,
'successful_products': successful_products,
'failed_products': failed_products,
'skipped_products': skipped_products,
'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
'average_metrics': avg_metrics,
'data_summary': {
'total_data_points': sum(data_points_list),
'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
'min_data_points': min(data_points_list) if data_points_list else 0,
'max_data_points': max(data_points_list) if data_points_list else 0
}
}