Files
bakery-ia/services/training/app/ml/trainer.py
Urtzi Alfaro 4073222888 Fix imports
2025-07-18 14:41:39 +02:00

174 lines
6.5 KiB
Python

"""
ML Training implementation
"""
import asyncio
import structlog
from typing import Dict, Any, List
import pandas as pd
from datetime import datetime
import joblib
import os
from prophet import Prophet
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from app.core.config import settings
# Module-level structlog logger; MLTrainer uses it to report per-product failures.
logger = structlog.get_logger()
class MLTrainer:
    """Train and evaluate one Prophet model per product.

    Fitted models are persisted with joblib under ``settings.MODEL_STORAGE_PATH``;
    that directory is created when the trainer is constructed.
    """

    # Columns read from ``external_data["weather"]`` records.
    _WEATHER_FEATURES = ("temperature", "humidity", "precipitation")
    # Columns read from ``external_data["traffic"]`` records.
    _TRAFFIC_FEATURES = ("traffic_volume",)
    # Every model gets all of these as Prophet regressors, in this order.
    _FEATURES = _WEATHER_FEATURES + _TRAFFIC_FEATURES

    def __init__(self):
        self.model_storage_path = settings.MODEL_STORAGE_PATH
        os.makedirs(self.model_storage_path, exist_ok=True)

    async def train_models(self, training_data: Dict[str, Any], job_id: str, db) -> Dict[str, Any]:
        """Train one model per product found in ``training_data["sales_data"]``.

        Args:
            training_data: Mapping with an optional ``"sales_data"`` list of sale
                records and an optional ``"external_data"`` mapping (see
                ``_add_external_features``). Missing or ``None`` values are
                treated as empty.
            job_id: Prefix used in the saved model file names.
            db: Accepted for interface compatibility; not used here.

        Returns:
            ``{product_name: model_info}`` for every product whose model trained
            successfully. A product that fails is logged and left out.
        """
        sales_data = training_data.get("sales_data") or []
        external_data = training_data.get("external_data") or {}

        models_result: Dict[str, Any] = {}
        for product_name, product_sales in self._group_by_product(sales_data).items():
            try:
                models_result[product_name] = await self._train_product_model(
                    product_name,
                    product_sales,
                    external_data,
                    job_id,
                )
            except Exception as e:
                # Best effort: one bad product must not abort the whole job.
                logger.error(f"Failed to train model for {product_name}: {e}", exc_info=True)
        return models_result

    def _group_by_product(self, sales_data: List[Dict]) -> Dict[str, List[Dict]]:
        """Group sale records by their ``"product_name"`` value.

        Products keep first-seen order; records without a product name are
        grouped under ``None``.
        """
        products: Dict[str, List[Dict]] = {}
        for sale in sales_data:
            products.setdefault(sale.get("product_name"), []).append(sale)
        return products

    async def _train_product_model(self, product_name: str, sales_data: List[Dict], external_data: Dict, job_id: str) -> Dict[str, Any]:
        """Fit and save a Prophet model for a single product.

        Sales (``date`` / ``quantity_sold`` columns) are summed per calendar day
        and joined with the external regressors before fitting. The blocking
        fit-and-save step runs in the default executor so the event loop stays
        responsive.

        Raises:
            ValueError: if there are fewer than two distinct sale days, which
                Prophet cannot fit.
            KeyError: if the records lack ``date`` or ``quantity_sold``.
        """
        df = pd.DataFrame(sales_data)
        if df.empty:
            raise ValueError(f"No sales data for {product_name!r}")

        # Truncate to midnight so intra-day timestamps aggregate into one row per day.
        df["date"] = self._to_day(df["date"])
        daily_sales = df.groupby("date")["quantity_sold"].sum().reset_index()
        daily_sales.columns = ["ds", "y"]
        if len(daily_sales) < 2:
            raise ValueError(
                f"Need at least 2 days of sales to train {product_name!r}, got {len(daily_sales)}"
            )

        daily_sales = self._add_external_features(daily_sales, external_data)

        hyperparameters = self._prophet_hyperparameters()
        model = Prophet(**hyperparameters)
        for feature in self._FEATURES:
            model.add_regressor(feature)

        # Product names are user data: keep path separators out of the file name.
        model_path = os.path.join(
            self.model_storage_path,
            f"{job_id}_{self._safe_filename(product_name)}_prophet_model.pkl",
        )
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._fit_and_save, model, daily_sales, model_path)

        return {
            "type": "prophet",
            "path": model_path,
            "training_samples": len(daily_sales),
            "features": list(self._FEATURES),
            "hyperparameters": hyperparameters,
        }

    @staticmethod
    def _prophet_hyperparameters() -> Dict[str, Any]:
        """Return the Prophet constructor arguments taken from settings."""
        return {
            "seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
            "daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
            "weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
            "yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY,
        }

    @staticmethod
    def _fit_and_save(model, daily_sales: pd.DataFrame, model_path: str) -> None:
        """Fit ``model`` on ``daily_sales`` and persist it to ``model_path`` (blocking)."""
        model.fit(daily_sales)
        joblib.dump(model, model_path)

    @staticmethod
    def _safe_filename(name: Any) -> str:
        """Return ``name`` as a string safe to use inside a single path component."""
        return "".join(ch if ch.isalnum() or ch in " -_." else "_" for ch in str(name))

    @staticmethod
    def _to_day(values: pd.Series) -> pd.Series:
        """Parse ``values`` as datetimes truncated to midnight, dropping any timezone.

        The timezone is removed while keeping the local wall-clock time, because
        Prophet does not accept tz-aware ``ds`` values.
        """
        parsed = pd.to_datetime(values)
        if parsed.dt.tz is not None:
            parsed = parsed.dt.tz_localize(None)
        return parsed.dt.normalize()

    def _add_external_features(self, daily_sales: pd.DataFrame, external_data: Dict) -> pd.DataFrame:
        """Left-join weather and traffic regressors onto ``daily_sales`` by day.

        ``external_data`` may hold ``"weather"`` and ``"traffic"`` lists of
        records with a ``date`` field plus feature fields. Several records for
        the same day are averaged, so the training rows are never duplicated.

        Every column in ``_FEATURES`` is guaranteed to exist and contain no NaN:
        precipitation gaps become 0, other gaps take the column mean, and a
        column with no data at all becomes 0.0, because Prophet rejects NaN
        regressors.

        Returns a new frame; ``daily_sales`` itself is not modified.
        """
        daily_sales = daily_sales.copy()
        external_data = external_data or {}

        sources = (
            (external_data.get("weather"), self._WEATHER_FEATURES),
            (external_data.get("traffic"), self._TRAFFIC_FEATURES),
        )
        for records, columns in sources:
            if not records:
                continue
            features = self._daily_feature_frame(records, columns)
            if features is not None:
                daily_sales = daily_sales.merge(features, on="ds", how="left")

        for column in self._FEATURES:
            if column not in daily_sales.columns:
                daily_sales[column] = np.nan

        daily_sales["precipitation"] = daily_sales["precipitation"].fillna(0)
        for column in ("temperature", "humidity", "traffic_volume"):
            daily_sales[column] = daily_sales[column].fillna(daily_sales[column].mean())

        # An all-NaN column has a NaN mean, so the fill above leaves it untouched.
        feature_columns = list(self._FEATURES)
        daily_sales[feature_columns] = daily_sales[feature_columns].fillna(0.0)
        return daily_sales

    def _daily_feature_frame(self, records: List[Dict], columns):
        """Build one row per day (``ds`` plus the available ``columns``) from ``records``.

        Values are coerced to numbers (unparseable entries become NaN) and
        averaged per day. Returns ``None`` when the records have no ``date``
        field or none of ``columns``.
        """
        frame = pd.DataFrame(records)
        present = [column for column in columns if column in frame.columns]
        if "date" not in frame.columns or not present:
            logger.warning(
                f"Ignoring external data without a 'date' column or any of {list(columns)}"
            )
            return None
        values = frame[present].apply(pd.to_numeric, errors="coerce")
        values["ds"] = self._to_day(frame["date"])
        return values.groupby("ds", as_index=False).mean()

    async def validate_models(self, models_result: Dict[str, Any], db) -> Dict[str, Any]:
        """Score every trained model listed in ``models_result``.

        Each entry's ``"path"`` is loaded and evaluated against the model's own
        training history (see ``_in_sample_metrics``) in the default executor.

        Args:
            models_result: Output of ``train_models``.
            db: Accepted for interface compatibility; not used here.

        Returns:
            ``{product_name: {"mape", "rmse", "mae", "r2_score"}}``. When a model
            cannot be loaded or scored, its metrics are ``None`` and an
            ``"error"`` key holds the exception message.
        """
        loop = asyncio.get_running_loop()
        validation_results: Dict[str, Any] = {}
        for product_name, model_data in models_result.items():
            try:
                validation_results[product_name] = await loop.run_in_executor(
                    None, self._in_sample_metrics, model_data.get("path")
                )
            except Exception as e:
                logger.error(f"Validation failed for {product_name}: {e}", exc_info=True)
                validation_results[product_name] = {
                    "mape": None,
                    "rmse": None,
                    "mae": None,
                    "r2_score": None,
                    "error": str(e),
                }
        return validation_results

    @staticmethod
    def _in_sample_metrics(model_path: str) -> Dict[str, Any]:
        """Load a saved Prophet model and score its fit on its training history (blocking).

        These are in-sample numbers, so they are optimistic compared with a
        held-out evaluation. MAPE is a percentage computed over days with
        non-zero sales only, and is ``None`` when every day had zero sales.
        """
        # joblib.load unpickles: only pass paths this service wrote itself.
        model = joblib.load(model_path)
        # predict() without a frame forecasts over the model's training history.
        forecast = model.predict()
        actual = model.history["y"].to_numpy(dtype=float)
        predicted = forecast["yhat"].to_numpy(dtype=float)

        nonzero = actual != 0
        mape = (
            float(np.mean(np.abs((actual[nonzero] - predicted[nonzero]) / actual[nonzero])) * 100)
            if nonzero.any()
            else None
        )
        return {
            "mape": mape,
            "rmse": float(np.sqrt(mean_squared_error(actual, predicted))),
            "mae": float(mean_absolute_error(actual, predicted)),
            "r2_score": float(r2_score(actual, predicted)),
        }