Add all the code for training service
This commit is contained in:
408
services/training/app/ml/prophet_manager.py
Normal file
408
services/training/app/ml/prophet_manager.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# services/training/app/ml/prophet_manager.py
|
||||
"""
|
||||
Enhanced Prophet Manager for Training Service
|
||||
Migrated from the monolithic backend to microservices architecture
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from prophet import Prophet
|
||||
import pickle
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import asyncio
|
||||
import os
|
||||
import joblib
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BakeryProphetManager:
|
||||
"""
|
||||
Enhanced Prophet model manager for the training service.
|
||||
Handles training, validation, and model persistence for bakery forecasting.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.models = {} # In-memory model storage
|
||||
self.model_metadata = {} # Store model metadata
|
||||
self.feature_scalers = {} # Store feature scalers per model
|
||||
|
||||
# Ensure model storage directory exists
|
||||
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
|
||||
|
||||
async def train_bakery_model(self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
df: pd.DataFrame,
|
||||
job_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Train a Prophet model for bakery forecasting with enhanced features.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_name: Product name
|
||||
df: Training data with 'ds' and 'y' columns plus regressors
|
||||
job_id: Training job identifier
|
||||
|
||||
Returns:
|
||||
Dictionary with model information and metrics
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Training bakery model for tenant {tenant_id}, product {product_name}")
|
||||
|
||||
# Validate input data
|
||||
await self._validate_training_data(df, product_name)
|
||||
|
||||
# Prepare data for Prophet
|
||||
prophet_data = await self._prepare_prophet_data(df)
|
||||
|
||||
# Get regressor columns
|
||||
regressor_columns = self._extract_regressor_columns(prophet_data)
|
||||
|
||||
# Initialize Prophet model with bakery-specific settings
|
||||
model = self._create_prophet_model(regressor_columns)
|
||||
|
||||
# Add regressors to model
|
||||
for regressor in regressor_columns:
|
||||
if regressor in prophet_data.columns:
|
||||
model.add_regressor(regressor)
|
||||
|
||||
# Fit the model
|
||||
model.fit(prophet_data)
|
||||
|
||||
# Generate model ID and store model
|
||||
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
model_path = await self._store_model(
|
||||
tenant_id, product_name, model, model_id, prophet_data, regressor_columns
|
||||
)
|
||||
|
||||
# Calculate training metrics
|
||||
training_metrics = await self._calculate_training_metrics(model, prophet_data)
|
||||
|
||||
# Prepare model information
|
||||
model_info = {
|
||||
"model_id": model_id,
|
||||
"model_path": model_path,
|
||||
"type": "prophet",
|
||||
"training_samples": len(prophet_data),
|
||||
"features": regressor_columns,
|
||||
"hyperparameters": {
|
||||
"seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
|
||||
"daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
|
||||
"weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
|
||||
"yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY
|
||||
},
|
||||
"training_metrics": training_metrics,
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
"data_period": {
|
||||
"start_date": prophet_data['ds'].min().isoformat(),
|
||||
"end_date": prophet_data['ds'].max().isoformat(),
|
||||
"total_days": len(prophet_data)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Model trained successfully for {product_name}")
|
||||
return model_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to train bakery model for {product_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def generate_forecast(self,
|
||||
model_path: str,
|
||||
future_dates: pd.DataFrame,
|
||||
regressor_columns: List[str]) -> pd.DataFrame:
|
||||
"""
|
||||
Generate forecast using a stored Prophet model.
|
||||
|
||||
Args:
|
||||
model_path: Path to the stored model
|
||||
future_dates: DataFrame with future dates and regressors
|
||||
regressor_columns: List of regressor column names
|
||||
|
||||
Returns:
|
||||
DataFrame with forecast results
|
||||
"""
|
||||
try:
|
||||
# Load the model
|
||||
model = joblib.load(model_path)
|
||||
|
||||
# Validate future data has required regressors
|
||||
for regressor in regressor_columns:
|
||||
if regressor not in future_dates.columns:
|
||||
logger.warning(f"Missing regressor {regressor}, filling with median")
|
||||
future_dates[regressor] = 0 # Default value
|
||||
|
||||
# Generate forecast
|
||||
forecast = model.predict(future_dates)
|
||||
|
||||
return forecast
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate forecast: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _validate_training_data(self, df: pd.DataFrame, product_name: str):
|
||||
"""Validate training data quality"""
|
||||
if df.empty:
|
||||
raise ValueError(f"No training data available for {product_name}")
|
||||
|
||||
if len(df) < settings.MIN_TRAINING_DATA_DAYS:
|
||||
raise ValueError(
|
||||
f"Insufficient training data for {product_name}: "
|
||||
f"{len(df)} days, minimum required: {settings.MIN_TRAINING_DATA_DAYS}"
|
||||
)
|
||||
|
||||
required_columns = ['ds', 'y']
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
if missing_columns:
|
||||
raise ValueError(f"Missing required columns: {missing_columns}")
|
||||
|
||||
# Check for valid date range
|
||||
if df['ds'].isna().any():
|
||||
raise ValueError("Invalid dates found in training data")
|
||||
|
||||
# Check for valid target values
|
||||
if df['y'].isna().all():
|
||||
raise ValueError("No valid target values found")
|
||||
|
||||
async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Prepare data for Prophet training"""
|
||||
prophet_data = df.copy()
|
||||
|
||||
# Ensure ds column is datetime
|
||||
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
|
||||
|
||||
# Handle missing values in target
|
||||
if prophet_data['y'].isna().any():
|
||||
logger.warning("Filling missing target values with interpolation")
|
||||
prophet_data['y'] = prophet_data['y'].interpolate(method='linear')
|
||||
|
||||
# Remove extreme outliers (values > 3 standard deviations)
|
||||
mean_val = prophet_data['y'].mean()
|
||||
std_val = prophet_data['y'].std()
|
||||
|
||||
if std_val > 0: # Avoid division by zero
|
||||
lower_bound = mean_val - 3 * std_val
|
||||
upper_bound = mean_val + 3 * std_val
|
||||
|
||||
before_count = len(prophet_data)
|
||||
prophet_data = prophet_data[
|
||||
(prophet_data['y'] >= lower_bound) &
|
||||
(prophet_data['y'] <= upper_bound)
|
||||
]
|
||||
after_count = len(prophet_data)
|
||||
|
||||
if before_count != after_count:
|
||||
logger.info(f"Removed {before_count - after_count} outliers")
|
||||
|
||||
# Ensure chronological order
|
||||
prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
|
||||
|
||||
# Fill missing values in regressors
|
||||
numeric_columns = prophet_data.select_dtypes(include=[np.number]).columns
|
||||
for col in numeric_columns:
|
||||
if col != 'y' and prophet_data[col].isna().any():
|
||||
prophet_data[col] = prophet_data[col].fillna(prophet_data[col].median())
|
||||
|
||||
return prophet_data
|
||||
|
||||
def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:
|
||||
"""Extract regressor columns from the dataframe"""
|
||||
excluded_columns = ['ds', 'y']
|
||||
regressor_columns = []
|
||||
|
||||
for col in df.columns:
|
||||
if col not in excluded_columns and df[col].dtype in ['int64', 'float64']:
|
||||
regressor_columns.append(col)
|
||||
|
||||
logger.info(f"Identified regressor columns: {regressor_columns}")
|
||||
return regressor_columns
|
||||
|
||||
def _create_prophet_model(self, regressor_columns: List[str]) -> Prophet:
|
||||
"""Create Prophet model with bakery-specific settings"""
|
||||
|
||||
# Get Spanish holidays
|
||||
holidays = self._get_spanish_holidays()
|
||||
|
||||
# Bakery-specific Prophet configuration
|
||||
model = Prophet(
|
||||
holidays=holidays if not holidays.empty else None,
|
||||
daily_seasonality=settings.PROPHET_DAILY_SEASONALITY,
|
||||
weekly_seasonality=settings.PROPHET_WEEKLY_SEASONALITY,
|
||||
yearly_seasonality=settings.PROPHET_YEARLY_SEASONALITY,
|
||||
seasonality_mode=settings.PROPHET_SEASONALITY_MODE,
|
||||
changepoint_prior_scale=0.05, # Conservative changepoint detection
|
||||
seasonality_prior_scale=10, # Strong seasonality for bakeries
|
||||
holidays_prior_scale=10, # Strong holiday effects
|
||||
interval_width=0.8, # 80% confidence intervals
|
||||
mcmc_samples=0, # Use MAP estimation (faster)
|
||||
uncertainty_samples=1000 # For uncertainty estimation
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _get_spanish_holidays(self) -> pd.DataFrame:
|
||||
"""Get Spanish holidays for Prophet model"""
|
||||
try:
|
||||
# Define major Spanish holidays that affect bakery sales
|
||||
holidays_list = []
|
||||
|
||||
years = range(2020, 2030) # Cover training and prediction period
|
||||
|
||||
for year in years:
|
||||
holidays_list.extend([
|
||||
{'holiday': 'new_year', 'ds': f'{year}-01-01'},
|
||||
{'holiday': 'epiphany', 'ds': f'{year}-01-06'},
|
||||
{'holiday': 'may_day', 'ds': f'{year}-05-01'},
|
||||
{'holiday': 'assumption', 'ds': f'{year}-08-15'},
|
||||
{'holiday': 'national_day', 'ds': f'{year}-10-12'},
|
||||
{'holiday': 'all_saints', 'ds': f'{year}-11-01'},
|
||||
{'holiday': 'constitution', 'ds': f'{year}-12-06'},
|
||||
{'holiday': 'immaculate', 'ds': f'{year}-12-08'},
|
||||
{'holiday': 'christmas', 'ds': f'{year}-12-25'},
|
||||
|
||||
# Madrid specific holidays
|
||||
{'holiday': 'madrid_patron', 'ds': f'{year}-05-15'}, # San Isidro
|
||||
{'holiday': 'madrid_community', 'ds': f'{year}-05-02'},
|
||||
])
|
||||
|
||||
holidays_df = pd.DataFrame(holidays_list)
|
||||
holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])
|
||||
|
||||
return holidays_df
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error creating holidays dataframe: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
async def _store_model(self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
model: Prophet,
|
||||
model_id: str,
|
||||
training_data: pd.DataFrame,
|
||||
regressor_columns: List[str]) -> str:
|
||||
"""Store model and metadata to filesystem"""
|
||||
|
||||
# Create model filename
|
||||
model_filename = f"{model_id}_prophet_model.pkl"
|
||||
model_path = os.path.join(settings.MODEL_STORAGE_PATH, model_filename)
|
||||
|
||||
# Store the model
|
||||
joblib.dump(model, model_path)
|
||||
|
||||
# Store metadata
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"model_id": model_id,
|
||||
"regressor_columns": regressor_columns,
|
||||
"training_samples": len(training_data),
|
||||
"training_period": {
|
||||
"start": training_data['ds'].min().isoformat(),
|
||||
"end": training_data['ds'].max().isoformat()
|
||||
},
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"model_type": "prophet",
|
||||
"file_path": model_path
|
||||
}
|
||||
|
||||
metadata_path = model_path.replace('.pkl', '_metadata.json')
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
# Store in memory for quick access
|
||||
model_key = f"{tenant_id}:{product_name}"
|
||||
self.models[model_key] = model
|
||||
self.model_metadata[model_key] = metadata
|
||||
|
||||
logger.info(f"Model stored at: {model_path}")
|
||||
return model_path
|
||||
|
||||
async def _calculate_training_metrics(self,
|
||||
model: Prophet,
|
||||
training_data: pd.DataFrame) -> Dict[str, float]:
|
||||
"""Calculate training metrics for the model"""
|
||||
try:
|
||||
# Generate in-sample predictions
|
||||
forecast = model.predict(training_data[['ds'] + [col for col in training_data.columns if col not in ['ds', 'y']]])
|
||||
|
||||
# Calculate metrics
|
||||
y_true = training_data['y'].values
|
||||
y_pred = forecast['yhat'].values
|
||||
|
||||
# Basic metrics
|
||||
mae = mean_absolute_error(y_true, y_pred)
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
rmse = np.sqrt(mse)
|
||||
|
||||
# MAPE (Mean Absolute Percentage Error)
|
||||
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100
|
||||
|
||||
# R-squared
|
||||
r2 = r2_score(y_true, y_pred)
|
||||
|
||||
return {
|
||||
"mae": round(mae, 2),
|
||||
"mse": round(mse, 2),
|
||||
"rmse": round(rmse, 2),
|
||||
"mape": round(mape, 2),
|
||||
"r2_score": round(r2, 4),
|
||||
"mean_actual": round(np.mean(y_true), 2),
|
||||
"mean_predicted": round(np.mean(y_pred), 2)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating training metrics: {e}")
|
||||
return {
|
||||
"mae": 0.0,
|
||||
"mse": 0.0,
|
||||
"rmse": 0.0,
|
||||
"mape": 0.0,
|
||||
"r2_score": 0.0,
|
||||
"mean_actual": 0.0,
|
||||
"mean_predicted": 0.0
|
||||
}
|
||||
|
||||
def get_model_info(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get model information for a specific tenant and product"""
|
||||
model_key = f"{tenant_id}:{product_name}"
|
||||
return self.model_metadata.get(model_key)
|
||||
|
||||
def list_models(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""List all models for a tenant"""
|
||||
tenant_models = []
|
||||
|
||||
for model_key, metadata in self.model_metadata.items():
|
||||
if metadata['tenant_id'] == tenant_id:
|
||||
tenant_models.append(metadata)
|
||||
|
||||
return tenant_models
|
||||
|
||||
async def cleanup_old_models(self, days_old: int = 30):
|
||||
"""Clean up old model files"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_old)
|
||||
|
||||
for model_path in Path(settings.MODEL_STORAGE_PATH).glob("*.pkl"):
|
||||
# Check file modification time
|
||||
if model_path.stat().st_mtime < cutoff_date.timestamp():
|
||||
# Remove model and metadata files
|
||||
model_path.unlink()
|
||||
|
||||
metadata_path = model_path.with_suffix('.json')
|
||||
if metadata_path.exists():
|
||||
metadata_path.unlink()
|
||||
|
||||
logger.info(f"Cleaned up old model: {model_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during model cleanup: {e}")
|
||||
Reference in New Issue
Block a user