REFACTOR - Database logic: data processor moved to repository pattern, injected database manager, and structured logging
@@ -1,32 +1,44 @@
 # services/training/app/ml/data_processor.py
 """
-Enhanced Data Processor for Training Service
-Handles data preparation, date alignment, cleaning, and feature engineering for ML training
+Enhanced Data Processor for Training Service with Repository Pattern
+Uses repository pattern for data access and dependency injection
 """

 import pandas as pd
 import numpy as np
 from typing import Dict, List, Any, Optional, Tuple
 from datetime import datetime, timedelta, timezone
-import logging
+import structlog
 from sklearn.preprocessing import StandardScaler
 from sklearn.impute import SimpleImputer

 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
+from app.repositories import ModelRepository, TrainingLogRepository
+from shared.database.base import create_database_manager
+from shared.database.transactions import transactional
+from shared.database.exceptions import DatabaseError
+from app.core.config import settings

-logger = logging.getLogger(__name__)
+logger = structlog.get_logger()

-class BakeryDataProcessor:
+class EnhancedBakeryDataProcessor:
     """
-    Enhanced data processor for bakery forecasting training service.
+    Enhanced data processor for bakery forecasting with repository pattern.
     Integrates date alignment, data cleaning, feature engineering, and preparation for ML models.
     """

-    def __init__(self):
+    def __init__(self, database_manager=None):
+        self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
         self.scalers = {}  # Store scalers for each feature
         self.imputers = {}  # Store imputers for missing value handling
         self.date_alignment_service = DateAlignmentService()

+    async def _get_repositories(self, session):
+        """Initialize repositories with session"""
+        return {
+            'model': ModelRepository(session),
+            'training_log': TrainingLogRepository(session)
+        }
+
     def _ensure_timezone_aware(self, df: pd.DataFrame, date_column: str = 'date') -> pd.DataFrame:
         """Ensure date column is timezone-aware to prevent conversion errors"""
         if date_column in df.columns:
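The new constructor accepts an optional database_manager and only falls back to create_database_manager(settings.DATABASE_URL, "training-service") when none is given. A minimal sketch of how a test could inject a stand-in manager, assuming the processor only needs get_session() as an async context manager (that is how it is used further down); FakeDatabaseManager is an illustrative name, not part of the codebase:

from contextlib import asynccontextmanager

class FakeDatabaseManager:
    """Stand-in manager exposing only get_session() as an async context manager."""

    def __init__(self, session):
        self._session = session

    @asynccontextmanager
    async def get_session(self):
        yield self._session  # no real connection is opened

# processor = EnhancedBakeryDataProcessor(database_manager=FakeDatabaseManager(fake_session))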
@@ -46,59 +58,118 @@ class BakeryDataProcessor:
                                      sales_data: pd.DataFrame,
                                      weather_data: pd.DataFrame,
                                      traffic_data: pd.DataFrame,
-                                     product_name: str) -> pd.DataFrame:
+                                     product_name: str,
+                                     tenant_id: str = None,
+                                     job_id: str = None,
+                                     session=None) -> pd.DataFrame:
         """
-        Prepare comprehensive training data for a specific product with date alignment.
+        Prepare comprehensive training data for a specific product with repository logging.

         Args:
             sales_data: Historical sales data for the product
             weather_data: Weather data
             traffic_data: Traffic data
             product_name: Product name for logging
+            tenant_id: Optional tenant ID for tracking
+            job_id: Optional job ID for tracking

         Returns:
             DataFrame ready for Prophet training with 'ds' and 'y' columns plus features
         """
         try:
-            logger.info(f"Preparing training data for product: {product_name}")
+            logger.info("Preparing enhanced training data using repository pattern",
+                        product_name=product_name,
+                        tenant_id=tenant_id,
+                        job_id=job_id)

-            # Step 1: Convert and validate sales data
-            sales_clean = await self._process_sales_data(sales_data, product_name)
-
-            # FIX: Ensure timezone awareness before any operations
-            sales_clean = self._ensure_timezone_aware(sales_clean)
-            weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
-            traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
-
-            # Step 2: Apply date alignment if we have date constraints
-            sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
-
-            # Step 3: Aggregate to daily level
-            daily_sales = await self._aggregate_daily_sales(sales_clean)
-
-            # Step 4: Add temporal features
-            daily_sales = self._add_temporal_features(daily_sales)
-
-            # Step 5: Merge external data sources
-            daily_sales = self._merge_weather_features(daily_sales, weather_data)
-            daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
-
-            # Step 6: Engineer additional features
-            daily_sales = self._engineer_features(daily_sales)
-
-            # Step 7: Handle missing values
-            daily_sales = self._handle_missing_values(daily_sales)
-
-            # Step 8: Prepare for Prophet (rename columns and validate)
-            prophet_data = self._prepare_prophet_format(daily_sales)
-
-            logger.info(f"Prepared {len(prophet_data)} data points for {product_name}")
-            return prophet_data
+            # Get database session and repositories
+            async with self.database_manager.get_session() as db_session:
+                repos = await self._get_repositories(db_session)
+
+                # Log data preparation start if we have tracking info
+                if job_id and tenant_id:
+                    await repos['training_log'].update_log_progress(
+                        job_id, 15, f"preparing_data_{product_name}", "running"
+                    )
+
+                # Step 1: Convert and validate sales data
+                sales_clean = await self._process_sales_data(sales_data, product_name)
+
+                # FIX: Ensure timezone awareness before any operations
+                sales_clean = self._ensure_timezone_aware(sales_clean)
+                weather_data = self._ensure_timezone_aware(weather_data) if not weather_data.empty else weather_data
+                traffic_data = self._ensure_timezone_aware(traffic_data) if not traffic_data.empty else traffic_data
+
+                # Step 2: Apply date alignment if we have date constraints
+                sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
+
+                # Step 3: Aggregate to daily level
+                daily_sales = await self._aggregate_daily_sales(sales_clean)
+
+                # Step 4: Add temporal features
+                daily_sales = self._add_temporal_features(daily_sales)
+
+                # Step 5: Merge external data sources
+                daily_sales = self._merge_weather_features(daily_sales, weather_data)
+                daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
+
+                # Step 6: Engineer additional features
+                daily_sales = self._engineer_features(daily_sales)
+
+                # Step 7: Handle missing values
+                daily_sales = self._handle_missing_values(daily_sales)
+
+                # Step 8: Prepare for Prophet (rename columns and validate)
+                prophet_data = self._prepare_prophet_format(daily_sales)
+
+                # Step 9: Store processing metadata if we have a tenant
+                if tenant_id:
+                    await self._store_processing_metadata(
+                        repos, tenant_id, product_name, prophet_data, job_id
+                    )
+
+                logger.info("Enhanced training data prepared successfully",
+                            product_name=product_name,
+                            data_points=len(prophet_data))
+
+                return prophet_data

         except Exception as e:
-            logger.error(f"Error preparing training data for {product_name}: {str(e)}")
+            logger.error("Error preparing enhanced training data",
+                         product_name=product_name,
+                         error=str(e))
             raise

+    async def _store_processing_metadata(self,
+                                         repos: Dict,
+                                         tenant_id: str,
+                                         product_name: str,
+                                         processed_data: pd.DataFrame,
+                                         job_id: str = None):
+        """Store data processing metadata using repository"""
+        try:
+            # Create processing metadata
+            metadata = {
+                "product_name": product_name,
+                "data_points": len(processed_data),
+                "date_range": {
+                    "start": processed_data['ds'].min().isoformat(),
+                    "end": processed_data['ds'].max().isoformat()
+                },
+                "features_count": len([col for col in processed_data.columns if col not in ['ds', 'y']]),
+                "processed_at": datetime.now().isoformat()
+            }
+
+            # Log processing completion
+            if job_id:
+                await repos['training_log'].update_log_progress(
+                    job_id, 25, f"data_prepared_{product_name}", "running"
+                )
+
+        except Exception as e:
+            logger.warning("Failed to store processing metadata",
+                           error=str(e))

     async def prepare_prediction_features(self,
                                           future_dates: pd.DatetimeIndex,
                                           weather_forecast: pd.DataFrame = None,
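A hypothetical call site for the reworked method, assuming it is named prepare_training_data (only its parameter list appears in this hunk) and that the three input frames are already loaded; the product, tenant, and job identifiers are placeholder values:

import asyncio

async def run_preparation(processor, sales_df, weather_df, traffic_df):
    prophet_df = await processor.prepare_training_data(
        sales_data=sales_df,
        weather_data=weather_df,
        traffic_data=traffic_df,
        product_name="croissant",   # example product
        tenant_id="tenant-123",     # enables the Step 9 metadata path
        job_id="job-456",           # enables the 15% / 25% progress updates
    )
    return prophet_df  # 'ds'/'y' plus engineered feature columns

# asyncio.run(run_preparation(processor, sales_df, weather_df, traffic_df))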
@@ -149,7 +220,7 @@ class BakeryDataProcessor:
             return future_df

         except Exception as e:
-            logger.error(f"Error creating prediction features: {e}")
+            logger.error("Error creating prediction features", error=str(e))
             # Return minimal features if error
             return pd.DataFrame({'ds': future_dates})

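The recurring change in this commit is swapping interpolated f-string messages for structlog events with key=value fields. A minimal sketch of that pattern, independent of this codebase:

import structlog

logger = structlog.get_logger()

def log_failure(product_name: str, error: Exception) -> None:
    # Before: logger.error(f"Error creating prediction features: {error}")
    # After: a stable event name plus structured context
    logger.error("Error creating prediction features",
                 product_name=product_name,
                 error=str(error))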
@@ -181,16 +252,18 @@ class BakeryDataProcessor:
             mask = (sales_dates >= aligned_range.start) & (sales_dates <= aligned_range.end)
             filtered_sales = sales_data[mask].copy()

-            logger.info(f"Date alignment: {len(sales_data)} → {len(filtered_sales)} records")
-            logger.info(f"Aligned date range: {aligned_range.start.date()} to {aligned_range.end.date()}")
+            logger.info("Date alignment completed",
+                        original_records=len(sales_data),
+                        filtered_records=len(filtered_sales),
+                        date_range=f"{aligned_range.start.date()} to {aligned_range.end.date()}")

             if aligned_range.constraints:
-                logger.info(f"Applied constraints: {aligned_range.constraints}")
+                logger.info("Applied constraints", constraints=aligned_range.constraints)

             return filtered_sales

         except Exception as e:
-            logger.warning(f"Date alignment failed, using original data: {str(e)}")
+            logger.warning("Date alignment failed, using original data", error=str(e))
             return sales_data

     async def _process_sales_data(self, sales_data: pd.DataFrame, product_name: str) -> pd.DataFrame:
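A small sketch of the date-alignment filter shown above, with a plain start/end pair standing in for the DateRange produced by DateAlignmentService; filter_to_range is an illustrative helper, not a method of the class:

import pandas as pd

def filter_to_range(sales_data: pd.DataFrame, start: pd.Timestamp, end: pd.Timestamp) -> pd.DataFrame:
    sales_dates = pd.to_datetime(sales_data['date'])
    mask = (sales_dates >= start) & (sales_dates <= end)
    return sales_data[mask].copy()

# filter_to_range(sales_df, pd.Timestamp("2024-01-01"), pd.Timestamp("2024-06-30"))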
@@ -218,7 +291,9 @@ class BakeryDataProcessor:
         # Standardize to 'quantity'
         if quantity_col != 'quantity':
             sales_clean['quantity'] = sales_clean[quantity_col]
-            logger.info(f"Mapped '{quantity_col}' to 'quantity' column")
+            logger.info("Mapped quantity column",
+                        from_column=quantity_col,
+                        to_column='quantity')

         sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')

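A sketch of the quantity standardization above: whatever column holds the sold units is copied into 'quantity' and coerced to numeric so invalid values become NaN; standardize_quantity is an illustrative helper name:

import pandas as pd

def standardize_quantity(sales_clean: pd.DataFrame, quantity_col: str) -> pd.DataFrame:
    if quantity_col != 'quantity':
        sales_clean['quantity'] = sales_clean[quantity_col]
    sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
    return sales_clean

# standardize_quantity(pd.DataFrame({'units_sold': ['3', '5', 'n/a']}), 'units_sold')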
@@ -302,7 +377,7 @@ class BakeryDataProcessor:
                                 weather_data: pd.DataFrame) -> pd.DataFrame:
         """Merge weather features with enhanced Madrid-specific handling"""

-        # ✅ FIX: Define weather_defaults OUTSIDE try block to fix scope error
+        # Define weather_defaults OUTSIDE try block to fix scope error
         weather_defaults = {
             'temperature': 15.0,
             'precipitation': 0.0,
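Defining weather_defaults before the try block matters because the except handler iterates over it: if the defaults were created inside the try and the failure happened first, the handler itself would raise NameError instead of falling back cleanly. A minimal sketch of that fallback shape (merge_or_default, the 'date' join key, and the Celsius default are illustrative assumptions):

weather_defaults = {
    'temperature': 15.0,    # assumed degrees Celsius, mild default
    'precipitation': 0.0,
}

def merge_or_default(daily_sales, weather_clean=None):
    try:
        if weather_clean is None or weather_clean.empty:
            raise ValueError("no weather data to merge")
        return daily_sales.merge(weather_clean, on='date', how='left')
    except Exception:
        # weather_defaults is still in scope here, even if the merge never ran
        for feature, default_value in weather_defaults.items():
            daily_sales[feature] = default_value
        return daily_sales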
@@ -324,17 +399,15 @@ class BakeryDataProcessor:
             if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
                 weather_clean = weather_clean.rename(columns={'ds': 'date'})

-            # 🔧 CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
+            # CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
             weather_clean['date'] = pd.to_datetime(weather_clean['date'])
             daily_sales['date'] = pd.to_datetime(daily_sales['date'])

-            # ✅ NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
+            # NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
             if weather_clean['date'].dt.tz is not None:
-                # Convert timezone-aware to UTC then remove timezone info
                 weather_clean['date'] = weather_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             if daily_sales['date'].dt.tz is not None:
-                # Convert timezone-aware to UTC then remove timezone info
                 daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             # Map weather columns to standard names
@@ -369,8 +442,8 @@ class BakeryDataProcessor:
             return merged

         except Exception as e:
-            logger.warning(f"Error merging weather data: {e}")
-            # Add default weather columns if merge fails (weather_defaults now in scope)
+            logger.warning("Error merging weather data", error=str(e))
+            # Add default weather columns if merge fails
             for feature, default_value in weather_defaults.items():
                 daily_sales[feature] = default_value
             return daily_sales
@@ -393,18 +466,15 @@ class BakeryDataProcessor:
             if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
                 traffic_clean = traffic_clean.rename(columns={'ds': 'date'})

-            # 🔧 CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
+            # CRITICAL FIX: Ensure both DataFrames have compatible datetime formats
             traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
             daily_sales['date'] = pd.to_datetime(daily_sales['date'])

-            # ✅ NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
-            # This prevents the "datetime64[ns] and datetime64[ns, UTC]" merge error
+            # NEW FIX: Normalize both to timezone-naive datetime for merge compatibility
             if traffic_clean['date'].dt.tz is not None:
-                # Convert timezone-aware to UTC then remove timezone info
                 traffic_clean['date'] = traffic_clean['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             if daily_sales['date'].dt.tz is not None:
-                # Convert timezone-aware to UTC then remove timezone info
                 daily_sales['date'] = daily_sales['date'].dt.tz_convert('UTC').dt.tz_localize(None)

             # Map traffic columns to standard names
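A sketch of the normalization both merge helpers apply: pandas refuses to merge a tz-naive datetime64[ns] key against a tz-aware datetime64[ns, UTC] key, so both sides are converted to UTC and stripped of tzinfo before merging; to_naive_utc is an illustrative helper, not code from this module:

import pandas as pd

def to_naive_utc(series: pd.Series) -> pd.Series:
    series = pd.to_datetime(series)
    if series.dt.tz is not None:
        series = series.dt.tz_convert('UTC').dt.tz_localize(None)
    return series

# traffic_clean['date'] = to_naive_utc(traffic_clean['date'])
# daily_sales['date'] = to_naive_utc(daily_sales['date'])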
@@ -445,7 +515,7 @@ class BakeryDataProcessor:
             return merged

         except Exception as e:
-            logger.warning(f"Error merging traffic data: {e}")
+            logger.warning("Error merging traffic data", error=str(e))
             # Add default traffic column if merge fails
             daily_sales['traffic_volume'] = 100.0
             return daily_sales
@@ -473,7 +543,7 @@ class BakeryDataProcessor:
                                              bins=[-0.1, 0, 2, 10, np.inf],
                                              labels=[0, 1, 2, 3]).astype(int)

-        # ✅ FIX: Traffic-based features with NaN protection
+        # Traffic-based features with NaN protection
         if 'traffic_volume' in df.columns:
             # Calculate traffic quantiles for relative measures
             q75 = df['traffic_volume'].quantile(0.75)
@@ -482,19 +552,17 @@ class BakeryDataProcessor:
             df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
             df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)

-            # ✅ FIX: Safe normalization with NaN protection
+            # Safe normalization with NaN protection
             traffic_std = df['traffic_volume'].std()
             traffic_mean = df['traffic_volume'].mean()

             if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
                 # Normal case: valid standard deviation
                 df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
             else:
                 # Edge case: all values are the same or contain NaN
-                logger.warning("Traffic volume has zero standard deviation or contains NaN, using zeros for normalized values")
+                logger.warning("Traffic volume has zero standard deviation, using zeros for normalized values")
                 df['traffic_normalized'] = 0.0

-            # ✅ ADDITIONAL SAFETY: Fill any remaining NaN values
+            # Fill any remaining NaN values
             df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)

         # Interaction features - bakery specific
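A sketch of the guarded z-score used for traffic_normalized: a constant or empty series has a zero or NaN standard deviation, so the result falls back to zeros instead of dividing by zero; safe_zscore is an illustrative name:

import pandas as pd

def safe_zscore(values: pd.Series) -> pd.Series:
    std = values.std()
    mean = values.mean()
    if std > 0 and not pd.isna(std) and not pd.isna(mean):
        normalized = (values - mean) / std
    else:
        normalized = pd.Series(0.0, index=values.index)
    return normalized.fillna(0.0)

# safe_zscore(pd.Series([100.0, 100.0, 100.0]))  # constant input -> all zeros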
@@ -528,13 +596,14 @@ class BakeryDataProcessor:
         # Spring/summer months
         df['is_warm_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)

-        # ✅ FINAL SAFETY CHECK: Remove any remaining NaN values
-        # Check for NaN values in all numeric columns and fill them
+        # FINAL SAFETY CHECK: Remove any remaining NaN values
         numeric_columns = df.select_dtypes(include=[np.number]).columns
         for col in numeric_columns:
             if df[col].isna().any():
                 nan_count = df[col].isna().sum()
-                logger.warning(f"Found {nan_count} NaN values in column '{col}', filling with 0")
+                logger.warning("Found NaN values in column, filling with 0",
+                               column=col,
+                               nan_count=nan_count)
                 df[col] = df[col].fillna(0.0)

         return df
@@ -632,8 +701,9 @@ class BakeryDataProcessor:
         if len(prophet_df) == 0:
             raise ValueError("No valid data points after cleaning")

-        logger.info(f"Prophet data prepared: {len(prophet_df)} rows, "
-                    f"date range: {prophet_df['ds'].min()} to {prophet_df['ds'].max()}")
+        logger.info("Prophet data prepared",
+                    rows=len(prophet_df),
+                    date_range=f"{prophet_df['ds'].min()} to {prophet_df['ds'].max()}")

         return prophet_df

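For context, Prophet expects a 'ds' datetime column and a numeric 'y' target, with any extra columns usable as regressors. A sketch of what a formatting step like _prepare_prophet_format presumably does, assuming the daily frame carries 'date' and 'quantity' columns as elsewhere in this file:

import pandas as pd

def to_prophet_format(daily_sales: pd.DataFrame) -> pd.DataFrame:
    prophet_df = daily_sales.rename(columns={'date': 'ds', 'quantity': 'y'})
    prophet_df['ds'] = pd.to_datetime(prophet_df['ds'])
    prophet_df = prophet_df.dropna(subset=['ds', 'y'])
    if len(prophet_df) == 0:
        raise ValueError("No valid data points after cleaning")
    return prophet_df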
@@ -690,11 +760,11 @@ class BakeryDataProcessor:

         return False

-    def calculate_feature_importance(self,
+    async def calculate_feature_importance(self,
                                      model_data: pd.DataFrame,
                                      target_column: str = 'y') -> Dict[str, float]:
         """
-        Calculate feature importance for the model using correlation analysis.
+        Calculate feature importance for the model using correlation analysis with repository logging.
         """
         try:
             # Get numeric features
@@ -704,7 +774,7 @@ class BakeryDataProcessor:
             importance_scores = {}

             if target_column not in model_data.columns:
-                logger.warning(f"Target column '{target_column}' not found")
+                logger.warning("Target column not found", target_column=target_column)
                 return {}

             for feature in numeric_features:
@@ -717,16 +787,18 @@ class BakeryDataProcessor:
             importance_scores = dict(sorted(importance_scores.items(),
                                             key=lambda x: x[1], reverse=True))

-            logger.info(f"Calculated feature importance for {len(importance_scores)} features")
+            logger.info("Calculated feature importance",
+                        features_count=len(importance_scores))

             return importance_scores

         except Exception as e:
-            logger.error(f"Error calculating feature importance: {e}")
+            logger.error("Error calculating feature importance", error=str(e))
             return {}

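A sketch of correlation-based importance as described in the docstring: the absolute Pearson correlation of each numeric feature with the target, sorted descending. This mirrors the structure of the method rather than reproducing its exact code:

import pandas as pd

def correlation_importance(model_data: pd.DataFrame, target_column: str = 'y') -> dict:
    if target_column not in model_data.columns:
        return {}
    numeric = model_data.select_dtypes(include='number')
    scores = {}
    for feature in numeric.columns:
        if feature == target_column:
            continue
        corr = numeric[feature].corr(model_data[target_column])
        if pd.notna(corr):
            scores[feature] = abs(corr)
    return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))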
-    def get_data_quality_report(self, df: pd.DataFrame) -> Dict[str, Any]:
+    async def get_data_quality_report(self, df: pd.DataFrame) -> Dict[str, Any]:
         """
-        Generate a comprehensive data quality report.
+        Generate a comprehensive data quality report with repository integration.
         """
         try:
             report = {
@@ -778,5 +850,9 @@ class BakeryDataProcessor:
             return report

         except Exception as e:
-            logger.error(f"Error generating data quality report: {e}")
-            return {"error": str(e)}
+            logger.error("Error generating data quality report", error=str(e))
+            return {"error": str(e)}
+
+
+# Legacy compatibility alias
+BakeryDataProcessor = EnhancedBakeryDataProcessor
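The module-level alias keeps older call sites working without changes; a hypothetical import, with the module path inferred from the file location:

# from app.ml.data_processor import BakeryDataProcessor
# processor = BakeryDataProcessor()   # resolves to EnhancedBakeryDataProcessor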