# services/training/app/ml/trainer.py
"""
2025-07-28 19:28:39 +02:00
ML Trainer - Main ML pipeline coordinator
Receives prepared data and orchestrates the complete ML training process
"""
2025-07-28 19:28:39 +02:00
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
2025-07-28 19:28:39 +02:00
from datetime import datetime
2025-07-19 16:59:37 +02:00
import logging
import uuid
2025-08-04 18:21:42 +02:00
import time
from datetime import datetime
2025-07-19 16:59:37 +02:00
from app.ml.data_processor import BakeryDataProcessor
2025-07-28 19:28:39 +02:00
from app.ml.prophet_manager import BakeryProphetManager
from app.services.training_orchestrator import TrainingDataSet
from app.core.config import settings
2025-07-28 19:28:39 +02:00
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
publish_data_validation_completed,
publish_job_step_completed
)

logger = logging.getLogger(__name__)


class BakeryMLTrainer:
    """
    Main ML trainer that orchestrates the complete ML training pipeline.
    Receives a prepared TrainingDataSet and coordinates data processing and model training.
    """

    def __init__(self, db_session: Optional[AsyncSession] = None):
        self.data_processor = BakeryDataProcessor()
        self.prophet_manager = BakeryProphetManager(db_session=db_session)

    async def train_tenant_models(self,
                                  tenant_id: str,
                                  training_dataset: TrainingDataSet,
                                  job_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Train models for all products using the prepared training dataset.

        Args:
            tenant_id: Tenant identifier
            training_dataset: Prepared training dataset with aligned dates
            job_id: Training job identifier

        Returns:
            Dictionary with training results for each product
        """
        if not job_id:
            job_id = f"ml_training_{tenant_id}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting ML training pipeline {job_id} for tenant {tenant_id}")

        try:
            # Convert the prepared data to DataFrames
            sales_df = pd.DataFrame(training_dataset.sales_data)
            weather_df = pd.DataFrame(training_dataset.weather_data)
            traffic_df = pd.DataFrame(training_dataset.traffic_data)

            # Validate input data
            await self._validate_input_data(sales_df, tenant_id)

            # Get the unique products from the sales data
            products = sales_df['product_name'].unique().tolist()
            logger.info(f"Training models for {len(products)} products: {products}")

            # Process data for each product
            logger.info("Processing data for all products...")
            processed_data = await self._process_all_products(
                sales_df, weather_df, traffic_df, products
            )
            await publish_job_progress(job_id, tenant_id, 20, "feature_engineering", estimated_time_remaining_minutes=7)

            # Train models for each processed product
            logger.info("Training models for all products...")
            training_results = await self._train_all_models(
                tenant_id, processed_data, job_id
            )

            # Calculate the overall training summary
            summary = self._calculate_training_summary(training_results)
            await publish_job_progress(job_id, tenant_id, 90, "model_validation", estimated_time_remaining_minutes=1)

            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "completed",
                "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
                "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
                "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
                "total_products": len(products),
                "training_results": training_results,
                "summary": summary,
                "data_info": {
                    "date_range": {
                        "start": training_dataset.date_range.start.isoformat(),
                        "end": training_dataset.date_range.end.isoformat(),
                        "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
                    },
                    "data_sources": [source.value for source in training_dataset.date_range.available_sources],
                    "constraints_applied": training_dataset.date_range.constraints
                },
                "completed_at": datetime.now().isoformat()
            }

            logger.info(f"ML training pipeline {job_id} completed successfully")
            return result

        except Exception as e:
            logger.error(f"ML training pipeline {job_id} failed: {str(e)}")
            raise

    async def train_single_product_model(self,
                                         tenant_id: str,
                                         product_name: str,
                                         training_dataset: TrainingDataSet,
                                         job_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Train a model for a single product using the prepared training dataset.

        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            training_dataset: Prepared training dataset
            job_id: Training job identifier

        Returns:
            Training result for the product
        """
        if not job_id:
            job_id = f"single_ml_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"

        logger.info(f"Starting single product ML training {job_id} for {product_name}")

        try:
            # Convert training data to DataFrames
            sales_df = pd.DataFrame(training_dataset.sales_data)
            weather_df = pd.DataFrame(training_dataset.weather_data)
            traffic_df = pd.DataFrame(training_dataset.traffic_data)

            # Filter sales data for the specific product
            product_sales = sales_df[sales_df['product_name'] == product_name].copy()
            if product_sales.empty:
                raise ValueError(f"No sales data found for product: {product_name}")

            # Process data for this specific product
            processed_data = await self.data_processor.prepare_training_data(
                sales_data=product_sales,
                weather_data=weather_df,
                traffic_data=traffic_df,
                product_name=product_name
            )

            # Train the model
            model_info = await self.prophet_manager.train_bakery_model(
                tenant_id=tenant_id,
                product_name=product_name,
                df=processed_data,
                job_id=job_id
            )

            result = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "product_name": product_name,
                "status": "success",
                "model_info": model_info,
                "data_points": len(processed_data),
                "data_info": {
                    "date_range": {
                        "start": training_dataset.date_range.start.isoformat(),
                        "end": training_dataset.date_range.end.isoformat(),
                        "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
                    },
                    "data_sources": [source.value for source in training_dataset.date_range.available_sources],
                    "constraints_applied": training_dataset.date_range.constraints
                },
                "completed_at": datetime.now().isoformat()
            }

            logger.info(f"Single product ML training {job_id} completed successfully")
            return result

        except Exception as e:
            logger.error(f"Single product ML training {job_id} failed: {str(e)}")
            raise

    async def evaluate_model_performance(self,
                                         tenant_id: str,
                                         product_name: str,
                                         model_path: str,
                                         test_dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Evaluate model performance using a held-out test dataset.

        Args:
            tenant_id: Tenant identifier
            product_name: Product name
            model_path: Path to the trained model
            test_dataset: Test dataset for evaluation

        Returns:
            Performance metrics
        """
        try:
            logger.info(f"Evaluating model performance for {product_name}")

            # Convert test data to DataFrames
            test_sales_df = pd.DataFrame(test_dataset.sales_data)
            test_weather_df = pd.DataFrame(test_dataset.weather_data)
            test_traffic_df = pd.DataFrame(test_dataset.traffic_data)

            # Filter for the specific product
            product_test_sales = test_sales_df[test_sales_df['product_name'] == product_name].copy()
            if product_test_sales.empty:
                raise ValueError(f"No test data found for product: {product_name}")

            # Process test data
            processed_test_data = await self.data_processor.prepare_training_data(
                sales_data=product_test_sales,
                weather_data=test_weather_df,
                traffic_data=test_traffic_df,
                product_name=product_name
            )

            # Create the future dataframe for prediction, carrying over the regressor columns
            future_dates = processed_test_data[['ds']].copy()
            regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
            for col in regressor_columns:
                future_dates[col] = processed_test_data[col]

            # Generate predictions
            forecast = await self.prophet_manager.generate_forecast(
                model_path=model_path,
                future_dates=future_dates,
                regressor_columns=regressor_columns
            )

            # Calculate performance metrics
            from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

            y_true = processed_test_data['y'].values
            y_pred = forecast['yhat'].values

            # Ensure both arrays are the same length
            min_len = min(len(y_true), len(y_pred))
            y_true = y_true[:min_len]
            y_pred = y_pred[:min_len]

            metrics = {
                "mae": float(mean_absolute_error(y_true, y_pred)),
                "rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
                "r2_score": float(r2_score(y_true, y_pred))
            }

            # Calculate MAPE safely: exclude near-zero actuals to avoid division blow-ups
            non_zero_mask = y_true > 0.1
            if np.sum(non_zero_mask) > 0:
                mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
                metrics["mape"] = float(min(mape, 200))  # Cap at 200%
            else:
                metrics["mape"] = 100.0

            result = {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "evaluation_metrics": metrics,
                "test_samples": len(processed_test_data),
                "prediction_samples": len(forecast),
                "test_period": {
                    "start": test_dataset.date_range.start.isoformat(),
                    "end": test_dataset.date_range.end.isoformat()
                },
                "evaluated_at": datetime.now().isoformat()
            }
            return result

        except Exception as e:
            logger.error(f"Model evaluation failed: {str(e)}")
            raise

    async def _validate_input_data(self, sales_df: pd.DataFrame, tenant_id: str):
        """Validate input sales data"""
        if sales_df.empty:
            raise ValueError(f"No sales data provided for tenant {tenant_id}")

        # Handle quantity column mapping
        if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
            sales_df['quantity'] = sales_df['quantity_sold']
            logger.info("Mapped 'quantity_sold' to 'quantity' column")

        required_columns = ['date', 'product_name', 'quantity']
        missing_columns = [col for col in required_columns if col not in sales_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # Check for valid dates
        try:
            sales_df['date'] = pd.to_datetime(sales_df['date'])
        except Exception:
            raise ValueError("Invalid date format in sales data")

        # Check for valid quantities. Note that pd.to_numeric with errors='coerce'
        # never raises; it turns unparseable values into NaN, so validate afterwards.
        if not pd.api.types.is_numeric_dtype(sales_df['quantity']):
            sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
            if sales_df['quantity'].isna().all():
                raise ValueError("Quantity column must be numeric")

    async def _process_all_products(self,
                                    sales_df: pd.DataFrame,
                                    weather_df: pd.DataFrame,
                                    traffic_df: pd.DataFrame,
                                    products: List[str]) -> Dict[str, pd.DataFrame]:
        """Process data for all products using the data processor"""
        processed_data = {}

        for product_name in products:
            try:
                logger.info(f"Processing data for product: {product_name}")

                # Filter sales data for this product
                product_sales = sales_df[sales_df['product_name'] == product_name].copy()
                if product_sales.empty:
                    logger.warning(f"No sales data found for product: {product_name}")
                    continue

                # Use the data processor to prepare training data
                processed_product_data = await self.data_processor.prepare_training_data(
                    sales_data=product_sales,
                    weather_data=weather_df,
                    traffic_data=traffic_df,
                    product_name=product_name
                )
                processed_data[product_name] = processed_product_data
                logger.info(f"Processed {len(processed_product_data)} data points for {product_name}")

            except Exception as e:
                logger.error(f"Failed to process data for {product_name}: {str(e)}")
                # Continue with the other products
                continue

        return processed_data

    def calculate_estimated_time_remaining(self, processing_times: List[float], completed: int, total: int) -> int:
        """
        Calculate the estimated time remaining based on actual processing times.

        Args:
            processing_times: Processing times for completed items (in seconds)
            completed: Number of items completed so far
            total: Total number of items to process

        Returns:
            Estimated time remaining in minutes
        """
        if not processing_times or completed >= total:
            return 0

        # Average processing time across all completed items
        avg_time_per_item = sum(processing_times) / len(processing_times)

        # Weighted average giving more weight to recent processing times
        if len(processing_times) > 3:
            # Use the last 3 items to capture recent performance
            recent_times = processing_times[-3:]
            recent_avg = sum(recent_times) / len(recent_times)
            # Weighted average: 70% recent, 30% overall
            avg_time_per_item = (recent_avg * 0.7) + (avg_time_per_item * 0.3)

        # Remaining items and estimated time
        remaining_items = total - completed
        estimated_seconds = remaining_items * avg_time_per_item

        # Convert to minutes, rounding up, with a floor of 1 minute
        estimated_minutes = max(1, int(estimated_seconds / 60) + (1 if estimated_seconds % 60 > 0 else 0))
        return estimated_minutes
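
    # Worked example (illustrative): with processing_times = [60, 70, 80, 90] seconds,
    # completed = 4 and total = 10, the overall average is 75 s and the recent
    # average over the last 3 items is 80 s, so the blended rate is
    # 0.7 * 80 + 0.3 * 75 = 78.5 s/item and the estimate for the 6 remaining
    # items is ceil(6 * 78.5 / 60) = 8 minutes.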

    async def _train_all_models(self,
                                tenant_id: str,
                                processed_data: Dict[str, pd.DataFrame],
                                job_id: str) -> Dict[str, Any]:
        """Train models for all processed products using the Prophet manager"""
        training_results = {}
        total_products = len(processed_data)
        base_progress = 45
        max_progress = 85  # model training advances overall job progress from 45% to 85%
        completed = 0
        processing_times = []  # Per-product processing times, in seconds

        for product_name, product_data in processed_data.items():
            product_start_time = time.time()
            try:
                logger.info(f"Training model for product: {product_name}")

                # Check whether we have enough data
                if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
                    training_results[product_name] = {
                        'status': 'skipped',
                        'reason': 'insufficient_data',
                        'data_points': len(product_data),
                        'min_required': settings.MIN_TRAINING_DATA_DAYS,
                        'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
                    }
                    logger.warning(f"Skipping {product_name}: insufficient data ({len(product_data)} < {settings.MIN_TRAINING_DATA_DAYS})")
                else:
                    # Train the model using the Prophet manager
                    model_info = await self.prophet_manager.train_bakery_model(
                        tenant_id=tenant_id,
                        product_name=product_name,
                        df=product_data,
                        job_id=job_id
                    )
                    training_results[product_name] = {
                        'status': 'success',
                        'model_info': model_info,
                        'data_points': len(product_data),
                        'trained_at': datetime.now().isoformat()
                    }
                    logger.info(f"Successfully trained model for {product_name}")

            except Exception as e:
                logger.error(f"Failed to train model for {product_name}: {str(e)}")
                training_results[product_name] = {
                    'status': 'error',
                    'error_message': str(e),
                    'data_points': len(product_data) if product_data is not None else 0,
                    'failed_at': datetime.now().isoformat()
                }

            # Record the processing time and report progress for every product,
            # whether it was trained, skipped, or failed
            processing_times.append(time.time() - product_start_time)
            completed += 1
            current_progress = base_progress + int((completed / total_products) * (max_progress - base_progress))
            estimated_time_remaining_minutes = self.calculate_estimated_time_remaining(
                processing_times, completed, total_products
            )
            await publish_job_progress(
                job_id,
                tenant_id,
                current_progress,
                "model_training",
                product_name,
                completed,  # number of products completed so far
                total_products,
                estimated_time_remaining_minutes=estimated_time_remaining_minutes
            )
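
            # Worked example (illustrative): with base_progress = 45, max_progress = 85
            # and total_products = 10, finishing the 5th product reports
            # 45 + int((5 / 10) * 40) = 65% overall progress.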
return training_results

    def _calculate_training_summary(self, training_results: Dict[str, Any]) -> Dict[str, Any]:
        """Calculate summary statistics from training results"""
        total_products = len(training_results)
        successful_products = len([r for r in training_results.values() if r.get('status') == 'success'])
        failed_products = len([r for r in training_results.values() if r.get('status') == 'error'])
        skipped_products = len([r for r in training_results.values() if r.get('status') == 'skipped'])

        # Calculate average training metrics for successful models
        successful_results = [r for r in training_results.values() if r.get('status') == 'success']
        avg_metrics = {}
        if successful_results:
            metrics_list = [r['model_info'].get('training_metrics', {}) for r in successful_results]
            if metrics_list and all(metrics_list):
                avg_metrics = {
                    'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
                    'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
                    'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
                    'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
                    'avg_improvement': round(np.mean([m.get('improvement_estimated', 0) for m in metrics_list]), 1)
                }

        # Calculate data quality insights
        data_points_list = [r.get('data_points', 0) for r in training_results.values()]

        return {
            'total_products': total_products,
            'successful_products': successful_products,
            'failed_products': failed_products,
            'skipped_products': skipped_products,
            'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
            'average_metrics': avg_metrics,
            'data_summary': {
                'total_data_points': sum(data_points_list),
                'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
                'min_data_points': min(data_points_list) if data_points_list else 0,
                'max_data_points': max(data_points_list) if data_points_list else 0
            }
        }
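

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): one way a caller
# might drive the trainer. It assumes a TrainingDataSet has already been
# prepared by the training orchestrator; `prepared_dataset` is hypothetical.
#
#     import asyncio
#
#     async def run(prepared_dataset: TrainingDataSet) -> None:
#         trainer = BakeryMLTrainer()
#         result = await trainer.train_tenant_models(
#             tenant_id="tenant_123",
#             training_dataset=prepared_dataset,
#         )
#         print(result["summary"])
#
#     # asyncio.run(run(prepared_dataset))
# ---------------------------------------------------------------------------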