# services/forecasting/app/services/prediction_service.py - FIXED SEASON FEATURE
"""
Prediction service for loading models and generating predictions

FIXED: Added missing 'season' feature that matches training service exactly
"""

import io
import pickle
from datetime import datetime
from typing import Any, Dict, List

import joblib
import numpy as np
import pandas as pd
import structlog

from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from shared.database.base import create_database_manager
from shared.clients import get_sales_client

logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")


class PredictionService:
    """
    Service for loading ML models and generating predictions with dependency injection

    Interfaces with trained Prophet models from the training service
    """

    def __init__(self, database_manager=None):
        self.database_manager = database_manager or create_database_manager(
            settings.DATABASE_URL, "forecasting-service"
        )
        self.model_cache = {}
        self.cache_ttl = 3600  # 1 hour cache
        # Initialize sales client for fetching historical data
        self.sales_client = get_sales_client(settings, "forecasting")
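
    # Usage sketch (hypothetical call site; the model_id and model_path values
    # below are illustrative, not identifiers from this codebase):
    #
    #     service = PredictionService()
    #     result = await service.predict(
    #         model_id="prophet-demo",
    #         model_path="minio://ml-models/tenant-a/prophet-demo.pkl",
    #         features={
    #             "date": "2024-06-01",
    #             "tenant_id": "<tenant-uuid>",
    #             "inventory_product_id": "<product-uuid>",
    #         },
    #         confidence_level=0.8,
    #     )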

    async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Validate prediction request"""
        try:
            required_fields = ["inventory_product_id", "model_id", "features"]
            missing_fields = [field for field in required_fields if field not in request]

            if missing_fields:
                return {
                    "is_valid": False,
                    "errors": [f"Missing required fields: {missing_fields}"],
                    "validation_passed": False
                }

            return {
                "is_valid": True,
                "errors": [],
                "validation_passed": True,
                "validated_fields": list(request.keys())
            }

        except Exception as e:
            logger.error("Validation error", error=str(e))
            return {
                "is_valid": False,
                "errors": [str(e)],
                "validation_passed": False
            }

    async def predict(self, model_id: str, model_path: str, features: Dict[str, Any],
                      confidence_level: float = 0.8) -> Dict[str, float]:
        """Generate prediction using trained model"""

        start_time = datetime.now()

        try:
            logger.info("Generating prediction",
                        model_id=model_id,
                        features_count=len(features))

            # Load model
            model = await self._load_model(model_id, model_path)

            if not model:
                raise ValueError(f"Model {model_id} not found or failed to load")

            # CRITICAL FIX: Fetch historical sales data and calculate historical features.
            # This populates lag, rolling, and trend features for better predictions.
            # Using 90 days for better trend analysis and more robust rolling statistics.
            if 'tenant_id' in features and 'inventory_product_id' in features and 'date' in features:
                try:
                    forecast_date = pd.to_datetime(features['date'])
                    historical_sales = await self._fetch_historical_sales(
                        tenant_id=features['tenant_id'],
                        inventory_product_id=features['inventory_product_id'],
                        forecast_date=forecast_date,
                        days_back=90  # Changed from 30 to 90 for better historical context
                    )

                    # Calculate historical features and merge into features dict
                    historical_features = self._calculate_historical_features(
                        historical_sales, forecast_date
                    )
                    features.update(historical_features)

                    logger.info("Historical features enriched",
                                lag_1_day=historical_features.get('lag_1_day'),
                                rolling_mean_7d=historical_features.get('rolling_mean_7d'))
                except Exception as e:
                    logger.warning("Failed to enrich with historical features, using defaults",
                                   error=str(e))
                    # Features dict will use defaults (0.0) from _prepare_prophet_features

            # CRITICAL FIX: Fetch POI (Point of Interest) features from the external service.
            # Prophet models trained with POI features REQUIRE them during prediction.
            # This prevents "Regressor 'poi_retail_total_count' missing" errors.
            if 'tenant_id' in features:
                # Default POI features, built once and reused on both the "no data"
                # and error paths. These match ALL features generated by the POI
                # detection service (format: poi_{category}_{feature_name}); the
                # categories mirror services/external/app/core/poi_config.py.
                poi_categories = [
                    'schools', 'offices', 'gyms_sports', 'residential', 'tourism',
                    'competitors', 'transport_hubs', 'coworking', 'retail'
                ]
                default_poi_features: Dict[str, Any] = {}
                for category in poi_categories:
                    default_poi_features.update({
                        f'poi_{category}_proximity_score': 0.0,
                        f'poi_{category}_weighted_proximity_score': 0.0,
                        f'poi_{category}_count_0_100m': 0,
                        f'poi_{category}_count_100_300m': 0,
                        f'poi_{category}_count_300_500m': 0,
                        f'poi_{category}_count_500_1000m': 0,
                        f'poi_{category}_total_count': 0,
                        f'poi_{category}_distance_to_nearest_m': 9999.0,
                        f'poi_{category}_has_within_100m': 0,
                        f'poi_{category}_has_within_300m': 0,
                        f'poi_{category}_has_within_500m': 0,
                    })

                try:
                    from shared.clients.external_client import ExternalServiceClient

                    external_client = ExternalServiceClient(settings, "forecasting-service")
                    poi_data = await external_client.get_poi_context(features['tenant_id'])

                    if poi_data and 'ml_features' in poi_data:
                        # Add all POI ML features to prediction features
                        poi_features = poi_data['ml_features']
                        features.update(poi_features)
                        logger.info("POI features enriched",
                                    tenant_id=features['tenant_id'],
                                    poi_feature_count=len(poi_features))
                    else:
                        # Provide default POI features to prevent model errors
                        logger.warning("No POI data available for tenant, using default POI features",
                                       tenant_id=features['tenant_id'])
                        features.update(default_poi_features)
                        logger.info("Using default POI features",
                                    tenant_id=features['tenant_id'],
                                    default_feature_count=len(default_poi_features))
                except Exception as e:
                    # On error, still provide default POI features to prevent prediction failures
                    logger.error("Failed to fetch POI features, using defaults",
                                 error=str(e),
                                 tenant_id=features.get('tenant_id'))
                    features.update(default_poi_features)

            # Prepare features for Prophet model
            prophet_df = self._prepare_prophet_features(features)

            # CRITICAL FIX: Validate that the model's required regressors are present.
            # Warn if using default values for features the model was trained with.
            if hasattr(model, 'extra_regressors'):
                model_regressors = set(model.extra_regressors.keys()) if model.extra_regressors else set()
                provided_features = set(prophet_df.columns) - {'ds'}

                # Check for missing regressors
                missing_regressors = model_regressors - provided_features

                if missing_regressors:
                    logger.warning(
                        "Model trained with regressors that are missing in prediction",
                        model_id=model_id,
                        missing_regressors=list(missing_regressors)[:10],  # Log first 10
                        total_missing=len(missing_regressors)
                    )

                # Check for default-valued critical features
                critical_features = {
                    'traffic_volume', 'temperature', 'precipitation',
                    'lag_1_day', 'rolling_mean_7d'
                }
                using_defaults = []
                for feature in critical_features:
                    if feature in model_regressors:
                        value = features.get(feature, 0)
                        # Check if using default/fallback values
                        if (feature == 'traffic_volume' and value == 100.0) or \
                           (feature == 'temperature' and value == 15.0) or \
                           (feature in ['lag_1_day', 'rolling_mean_7d'] and value == 0.0):
                            using_defaults.append(feature)

                if using_defaults:
                    logger.warning(
                        "Using default values for critical model features",
                        model_id=model_id,
                        features_with_defaults=using_defaults
                    )

            # Generate prediction
            forecast = model.predict(prophet_df)

            # Extract prediction values
            prediction_value = float(forecast['yhat'].iloc[0])
            lower_bound = float(forecast['yhat_lower'].iloc[0])
            upper_bound = float(forecast['yhat_upper'].iloc[0])

            # Calculate confidence interval
            confidence_interval = upper_bound - lower_bound

            # Adjust confidence based on data freshness if historical features were calculated
            adjusted_confidence_level = confidence_level
            data_availability_score = features.get('historical_data_availability_score', 1.0)  # Default to 1.0 if not available

            # Reduce confidence if historical data is significantly old
            if data_availability_score < 0.5:
                # For data availability score < 0.5 (more than 90 days old), reduce confidence
                adjusted_confidence_level = max(0.6, confidence_level * data_availability_score)

                # Widen the confidence interval to reflect the added uncertainty
                adjustment_factor = 1.0 + (0.5 * (1.0 - data_availability_score))  # Up to 50% wider interval
                adjusted_lower_bound = prediction_value - (prediction_value - lower_bound) * adjustment_factor
                adjusted_upper_bound = prediction_value + (upper_bound - prediction_value) * adjustment_factor

                logger.info("Adjusted prediction confidence due to stale historical data",
                            original_confidence=confidence_level,
                            adjusted_confidence=adjusted_confidence_level,
                            data_availability_score=data_availability_score,
                            original_interval=confidence_interval,
                            adjusted_interval=adjusted_upper_bound - adjusted_lower_bound)

                lower_bound = max(0, adjusted_lower_bound)
                upper_bound = adjusted_upper_bound
                confidence_interval = upper_bound - lower_bound

            result = {
                "prediction": max(0, prediction_value),  # Ensure non-negative
                "lower_bound": max(0, lower_bound),
                "upper_bound": max(0, upper_bound),
                "confidence_interval": confidence_interval,
                "confidence_level": adjusted_confidence_level,
                "data_freshness_score": data_availability_score  # Include data freshness in result
            }

            # Record metrics with proper registration and error handling
            processing_time = (datetime.now() - start_time).total_seconds()
            try:
                # Register metrics if not already registered
                if "prediction_processing_time" not in metrics._histograms:
                    metrics.register_histogram(
                        "prediction_processing_time",
                        "Time taken to process predictions",
                        labels=['service', 'model_type']
                    )

                if "predictions_served_total" not in metrics._counters:
                    try:
                        metrics.register_counter(
                            "predictions_served_total",
                            "Total number of predictions served",
                            labels=['service', 'status']
                        )
                    except Exception as reg_error:
                        # Metric might already exist in global registry
                        logger.debug("Counter already exists in registry", error=str(reg_error))

                # Now record the metrics - try with expected labels, fall back if needed
                try:
                    metrics.observe_histogram(
                        "prediction_processing_time",
                        processing_time,
                        labels={'service': 'forecasting-service', 'model_type': 'prophet'}
                    )
                    metrics.increment_counter(
                        "predictions_served_total",
                        labels={'service': 'forecasting-service', 'status': 'success'}
                    )
                except Exception as label_error:
                    # If specific labels fail, try without labels to avoid breaking predictions
                    logger.warning("Failed to record metrics with labels, trying without", error=str(label_error))
                    try:
                        metrics.observe_histogram("prediction_processing_time", processing_time)
                        metrics.increment_counter("predictions_served_total")
                    except Exception as no_label_error:
                        logger.warning("Failed to record metrics even without labels", error=str(no_label_error))

            except Exception as metrics_error:
                # Log metrics error but don't fail the prediction
                logger.warning("Failed to register or record metrics", error=str(metrics_error))

            logger.info("Prediction generated successfully",
                        model_id=model_id,
                        prediction=result["prediction"],
                        processing_time=processing_time)

            return result

        except Exception as e:
            logger.error("Error generating prediction",
                         error=str(e),
                         model_id=model_id)
            # Record error metrics with robust error handling
            try:
                if "prediction_errors_total" not in metrics._counters:
                    metrics.register_counter(
                        "prediction_errors_total",
                        "Total number of prediction errors",
                        labels=['service', 'error_type']
                    )

                # Try with labels first, then without if that fails
                try:
                    metrics.increment_counter(
                        "prediction_errors_total",
                        labels={'service': 'forecasting-service', 'error_type': 'prediction_failed'}
                    )
                except Exception as label_error:
                    logger.debug("Failed to record error metrics with labels", error=str(label_error))
                    try:
                        metrics.increment_counter("prediction_errors_total")
                    except Exception as no_label_error:
                        logger.warning("Failed to record error metrics even without labels", error=str(no_label_error))
            except Exception as registration_error:
                logger.warning("Failed to register error metrics", error=str(registration_error))
            raise
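
    # Shape of a successful predict() result (numeric values illustrative):
    #
    #     {
    #         "prediction": 42.7,
    #         "lower_bound": 31.2,
    #         "upper_bound": 55.9,
    #         "confidence_interval": 24.7,
    #         "confidence_level": 0.8,
    #         "data_freshness_score": 1.0,
    #     }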

    async def predict_with_weather_forecast(
        self,
        model_id: str,
        model_path: str,
        features: Dict[str, Any],
        tenant_id: str,
        days: int = 7,
        confidence_level: float = 0.8
    ) -> List[Dict[str, float]]:
        """
        Generate predictions enriched with real weather forecast data

        This method:
        1. Loads the trained ML model
        2. Fetches the real weather forecast from the external service
        3. Enriches prediction features with actual forecast data
        4. Generates weather-aware predictions

        Args:
            model_id: ID of the trained model
            model_path: Path to model file
            features: Base features for prediction
            tenant_id: Tenant ID for weather forecast
            days: Number of days to forecast
            confidence_level: Confidence level for predictions

        Returns:
            List of predictions with weather-aware adjustments
        """
        from app.services.data_client import data_client

        start_time = datetime.now()

        try:
            logger.info("Generating weather-aware predictions",
                        model_id=model_id,
                        days=days)

            # Step 1: Load ML model
            model = await self._load_model(model_id, model_path)
            if not model:
                raise ValueError(f"Model {model_id} not found")

            # Step 2: Fetch real weather forecast
            latitude = features.get('latitude', 40.4168)
            longitude = features.get('longitude', -3.7038)

            weather_forecast = await data_client.fetch_weather_forecast(
                tenant_id=tenant_id,
                days=days,
                latitude=latitude,
                longitude=longitude
            )

            logger.info(f"Fetched weather forecast for {len(weather_forecast)} days",
                        tenant_id=tenant_id)

            # Step 3: Generate predictions for each day with weather data
            predictions = []
            base_date = pd.to_datetime(features['date'])

            for day_offset in range(days):
                # Get weather for this specific day
                day_weather = weather_forecast[day_offset] if day_offset < len(weather_forecast) else {}

                # Enrich features with the actual weather forecast.
                # BUG FIX: advance the forecast date per day_offset; previously
                # every day in the horizon reused the same base date.
                enriched_features = features.copy()
                enriched_features.update({
                    'date': (base_date + pd.Timedelta(days=day_offset)).strftime("%Y-%m-%d"),
                    'temperature': day_weather.get('temperature', features.get('temperature', 20.0)),
                    'precipitation': day_weather.get('precipitation', features.get('precipitation', 0.0)),
                    'humidity': day_weather.get('humidity', features.get('humidity', 60.0)),
                    'wind_speed': day_weather.get('wind_speed', features.get('wind_speed', 10.0)),
                    'pressure': day_weather.get('pressure', features.get('pressure', 1013.0)),
                    'weather_description': day_weather.get('description', 'Clear')
                })

                # CRITICAL FIX: Fetch historical sales data and calculate historical features.
                # This populates lag, rolling, and trend features for better predictions.
                # Using 90 days for better trend analysis and more robust rolling statistics.
                if 'tenant_id' in enriched_features and 'inventory_product_id' in enriched_features and 'date' in enriched_features:
                    try:
                        forecast_date = pd.to_datetime(enriched_features['date'])
                        historical_sales = await self._fetch_historical_sales(
                            tenant_id=enriched_features['tenant_id'],
                            inventory_product_id=enriched_features['inventory_product_id'],
                            forecast_date=forecast_date,
                            days_back=90  # Changed from 30 to 90 for better historical context
                        )

                        # Calculate historical features and merge into features dict
                        historical_features = self._calculate_historical_features(
                            historical_sales, forecast_date
                        )
                        enriched_features.update(historical_features)

                        logger.info("Historical features enriched",
                                    lag_1_day=historical_features.get('lag_1_day'),
                                    rolling_mean_7d=historical_features.get('rolling_mean_7d'))
                    except Exception as e:
                        logger.warning("Failed to enrich with historical features, using defaults",
                                       error=str(e))
                        # Features dict will use defaults (0.0) from _prepare_prophet_features

                # Prepare Prophet dataframe with weather features
                prophet_df = self._prepare_prophet_features(enriched_features)

                # Generate prediction for this day
                forecast = model.predict(prophet_df)

                prediction_value = float(forecast['yhat'].iloc[0])
                lower_bound = float(forecast['yhat_lower'].iloc[0])
                upper_bound = float(forecast['yhat_upper'].iloc[0])

                # Calculate confidence adjustment based on data freshness
                current_confidence_level = confidence_level
                data_availability_score = enriched_features.get('historical_data_availability_score', 1.0)  # Default to 1.0 if not available

                # Reduce confidence if historical data is significantly old
                if data_availability_score < 0.5:
                    # For data availability score < 0.5 (more than 90 days old), reduce confidence
                    current_confidence_level = max(0.6, confidence_level * data_availability_score)

                    # Widen the confidence interval to reflect the added uncertainty
                    adjustment_factor = 1.0 + (0.5 * (1.0 - data_availability_score))  # Up to 50% wider interval
                    adjusted_lower_bound = prediction_value - (prediction_value - lower_bound) * adjustment_factor
                    adjusted_upper_bound = prediction_value + (upper_bound - prediction_value) * adjustment_factor

                    logger.info("Adjusted weather prediction confidence due to stale historical data",
                                original_confidence=confidence_level,
                                adjusted_confidence=current_confidence_level,
                                data_availability_score=data_availability_score)

                    lower_bound = max(0, adjusted_lower_bound)
                    upper_bound = adjusted_upper_bound

                # Apply weather-based adjustments (business rules)
                adjusted_prediction = self._apply_weather_adjustments(
                    prediction_value,
                    day_weather,
                    features.get('product_category', 'general')
                )

                predictions.append({
                    "date": enriched_features['date'],
                    "prediction": max(0, adjusted_prediction),
                    "lower_bound": max(0, lower_bound),
                    "upper_bound": max(0, upper_bound),
                    "confidence_level": current_confidence_level,
                    "data_freshness_score": data_availability_score,  # Include data freshness in result
                    "weather": {
                        "temperature": enriched_features['temperature'],
                        "precipitation": enriched_features['precipitation'],
                        "description": enriched_features['weather_description']
                    }
                })

            processing_time = (datetime.now() - start_time).total_seconds()

            logger.info("Weather-aware predictions generated",
                        model_id=model_id,
                        days=len(predictions),
                        processing_time=processing_time)

            return predictions

        except Exception as e:
            logger.error("Error generating weather-aware predictions",
                         error=str(e),
                         model_id=model_id)
            raise
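
    # Each day_weather entry returned by data_client.fetch_weather_forecast is
    # expected to carry the keys below (the .get() fallbacks above cover any
    # that are absent); values are illustrative:
    #
    #     {"temperature": 27.5, "precipitation": 0.0, "humidity": 55.0,
    #      "wind_speed": 12.0, "pressure": 1015.0, "description": "Clear"}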

    def _apply_weather_adjustments(
        self,
        base_prediction: float,
        weather: Dict[str, Any],
        product_category: str
    ) -> float:
        """
        Apply business rules based on weather conditions

        Adjusts predictions based on the real weather forecast
        """
        adjusted = base_prediction
        temp = weather.get('temperature', 20.0)
        precip = weather.get('precipitation', 0.0)

        # Temperature-based adjustments
        if product_category == 'ice_cream':
            if temp > 30:
                adjusted *= 1.4  # +40% for very hot days
            elif temp > 25:
                adjusted *= 1.2  # +20% for hot days
            elif temp < 15:
                adjusted *= 0.7  # -30% for cold days

        elif product_category == 'bread':
            if temp > 30:
                adjusted *= 0.9  # -10% for very hot days
            elif temp < 10:
                adjusted *= 1.1  # +10% for cold days

        elif product_category == 'coffee':
            if temp < 15:
                adjusted *= 1.2  # +20% for cold days
            elif precip > 5:
                adjusted *= 1.15  # +15% for rainy days

        # Precipitation-based adjustments
        if precip > 10:  # Heavy rain
            if product_category in ['pastry', 'coffee']:
                adjusted *= 1.2  # People stay indoors, buy comfort food

        return adjusted
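
    # Worked example of the rules above: a base prediction of 100 for 'coffee'
    # on a cold (12 °C), very rainy (12 mm) day becomes
    # 100 * 1.2 (cold) * 1.2 (heavy rain) = 144 units.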

    async def _load_model(self, model_id: str, model_path: str):
        """Load model from MinIO with improved validation and error handling"""

        # Check cache first.
        # BUG FIX: use total_seconds() for the TTL check; timedelta.seconds is
        # only the seconds component and wraps around after 24 hours.
        if model_id in self.model_cache:
            cached_model, cached_time = self.model_cache[model_id]
            if (datetime.now() - cached_time).total_seconds() < self.cache_ttl:
                logger.debug(f"Model loaded from cache: {model_id}")
                return cached_model

        # Validate MinIO path format
        if not await self._validate_model_file(model_path):
            logger.error(f"Model path not valid: {model_path}")
            return None

        try:
            # Load from MinIO
            model = await self._load_model_safely(model_path)

            if model is None:
                logger.error(f"Failed to load model from MinIO: {model_path}")
                return None

            # Cache the model
            self.model_cache[model_id] = (model, datetime.now())
            logger.info(f"Model loaded successfully from MinIO: {model_path}")
            return model

        except Exception as e:
            logger.error(f"Error loading model from MinIO: {e}")
            return None

    async def _load_model_safely(self, model_path: str):
        """Load model from MinIO storage (clean implementation - MinIO only)"""
        try:
            # Parse MinIO path: minio://bucket_name/object_path
            _, bucket_and_path = model_path.split("://", 1)
            bucket_name, object_name = bucket_and_path.split("/", 1)

            logger.debug(f"Loading model from MinIO: {bucket_name}/{object_name}")

            # Use MinIO client
            from shared.clients.minio_client import minio_client

            # Download model data
            model_data = minio_client.get_object(bucket_name, object_name)
            if not model_data:
                logger.error(f"Failed to download model from MinIO: {model_path}")
                return None

            # Try joblib first (using BytesIO, since joblib.load reads from file-like objects)
            try:
                buffer = io.BytesIO(model_data)
                model = joblib.load(buffer)
                logger.info("Model loaded successfully from MinIO with joblib")
                return model
            except Exception as e:
                logger.warning(f"Joblib loading from MinIO failed: {e}")

            # Try pickle as a fallback
            try:
                model = pickle.loads(model_data)
                logger.info("Model loaded successfully from MinIO with pickle")
                return model
            except Exception as e:
                logger.warning(f"Pickle loading from MinIO failed: {e}")

            logger.error(f"All loading methods failed for MinIO object: {model_path}")
            return None

        except Exception as e:
            logger.error(f"Failed to load model from MinIO: {model_path}, error: {e}")
            return None
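
    # Path parsing sketch (bucket and object names are illustrative):
    #
    #     "minio://ml-models/tenant-a/prophet.pkl"
    #     -> bucket_name = "ml-models", object_name = "tenant-a/prophet.pkl"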

    async def _validate_model_file(self, model_path: str) -> bool:
        """Validate the MinIO model path and check that the object exists"""
        try:
            # Validate MinIO path format
            if not model_path.startswith("minio://"):
                logger.error(f"Invalid model path format (expected minio://): {model_path}")
                return False

            # Parse MinIO path
            try:
                _, bucket_and_path = model_path.split("://", 1)
                bucket_name, object_name = bucket_and_path.split("/", 1)
            except ValueError:
                logger.error(f"Cannot parse MinIO path: {model_path}")
                return False

            # Check whether the object exists in MinIO
            from shared.clients.minio_client import minio_client

            if not minio_client.object_exists(bucket_name, object_name):
                logger.error(f"Model object not found in MinIO: {bucket_name}/{object_name}")
                return False

            # Check the object metadata for size validation
            metadata = minio_client.get_object_metadata(bucket_name, object_name)
            if metadata:
                file_size = metadata.get("size", 0)
                if file_size < 1024:
                    logger.warning(f"Model object too small ({file_size} bytes): {model_path}")
                    return False

                logger.debug(f"Model validated in MinIO: {bucket_name}/{object_name}, size={file_size}")

            return True

        except Exception as e:
            logger.error(f"Model validation error: {e}")
            return False

    async def _fetch_historical_sales(
        self,
        tenant_id: str,
        inventory_product_id: str,
        forecast_date: datetime,
        days_back: int = 90
    ) -> pd.Series:
        """
        Fetch historical sales data for calculating lagged and rolling features.
        Enhanced to handle cases where recent data is not available by extending
        the search to find the most recent data if needed.

        Args:
            tenant_id: Tenant UUID
            inventory_product_id: Product UUID
            forecast_date: The date we're forecasting for
            days_back: Number of days of history to fetch (default 90 for better trend analysis)

        Returns:
            pandas Series of sales quantities indexed by date
        """
        try:
            # Calculate the initial date range for recent data
            end_date = forecast_date - pd.Timedelta(days=1)  # Day before forecast
            start_date = end_date - pd.Timedelta(days=days_back)

            logger.debug("Fetching historical sales for feature calculation",
                         tenant_id=tenant_id,
                         product_id=inventory_product_id,
                         start_date=start_date.date(),
                         end_date=end_date.date(),
                         days_back=days_back)

            # First, try to fetch sales data from the recent period
            sales_data = await self.sales_client.get_sales_data(
                tenant_id=tenant_id,
                start_date=start_date.strftime("%Y-%m-%d"),
                end_date=end_date.strftime("%Y-%m-%d"),
                product_id=inventory_product_id,
                aggregation="daily"
            )

            # If no recent data was found, search for the most recent available data
            if not sales_data:
                logger.info("No recent sales data found, expanding search to find most recent data",
                            tenant_id=tenant_id,
                            product_id=inventory_product_id)

                # Search for available data in larger time windows (up to 2 years back)
                search_windows = [365, 730]  # 1 year, 2 years

                for window_days in search_windows:
                    extended_start_date = forecast_date - pd.Timedelta(days=window_days)

                    logger.debug("Expanding search window for historical data",
                                 start_date=extended_start_date.date(),
                                 end_date=end_date.date(),
                                 window_days=window_days)

                    sales_data = await self.sales_client.get_sales_data(
                        tenant_id=tenant_id,
                        start_date=extended_start_date.strftime("%Y-%m-%d"),
                        end_date=end_date.strftime("%Y-%m-%d"),
                        product_id=inventory_product_id,
                        aggregation="daily"
                    )

                    if sales_data:
                        logger.info("Found historical data in expanded search window",
                                    tenant_id=tenant_id,
                                    product_id=inventory_product_id,
                                    data_start=sales_data[0]['sale_date'] if sales_data else "None",
                                    data_end=sales_data[-1]['sale_date'] if sales_data else "None",
                                    window_days=window_days)
                        break

            if not sales_data:
                logger.warning("No historical sales data found in any search window",
                               tenant_id=tenant_id,
                               product_id=inventory_product_id)
                return pd.Series(dtype=float)

            # Convert to a pandas DataFrame and check that it has the expected structure
            df = pd.DataFrame(sales_data)

            if df.empty:
                logger.warning("No historical sales data returned from API")
                return pd.Series(dtype=float)

            # Check the available columns and find the date column
            available_columns = list(df.columns)
            logger.debug(f"Available sales data columns: {available_columns}")

            # Check for alternative date column names
            date_columns = ['sale_date', 'date', 'forecast_date', 'datetime', 'timestamp']
            date_column = None
            for col in date_columns:
                if col in df.columns:
                    date_column = col
                    break

            if date_column is None:
                logger.error(f"Sales data missing expected date column. Available columns: {available_columns}")
                logger.debug(f"Sample of sales data: {df.head()}")
                return pd.Series(dtype=float)

            df['sale_date'] = pd.to_datetime(df[date_column])
            df = df.set_index('sale_date')

            # Extract the quantity column (could be 'quantity' or 'total_quantity')
            if 'quantity' in df.columns:
                series = df['quantity']
            elif 'total_quantity' in df.columns:
                series = df['total_quantity']
            else:
                logger.warning("Sales data missing quantity field",
                               columns=list(df.columns))
                return pd.Series(dtype=float)

            logger.debug("Historical sales fetched successfully",
                         records=len(series),
                         date_range=f"{series.index.min()} to {series.index.max()}")

            return series.sort_index()

        except Exception as e:
            logger.error("Error fetching historical sales",
                         error=str(e),
                         tenant_id=tenant_id,
                         product_id=inventory_product_id)
            return pd.Series(dtype=float)
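
    # Expected shape of each sales_data record, inferred from the column
    # handling above (values illustrative):
    #
    #     {"sale_date": "2024-05-31", "quantity": 18}   # or "total_quantity"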

    def _calculate_historical_features(
        self,
        historical_sales: pd.Series,
        forecast_date: datetime
    ) -> Dict[str, float]:
        """
        Calculate lagged, rolling, and trend features from historical sales data.
        Enhanced to handle cases where recent data is not available by using
        the available historical data with appropriate temporal adjustments.

        Now uses the shared feature calculator for consistency with the training service.

        Args:
            historical_sales: Series of sales quantities indexed by date
            forecast_date: The date we're forecasting for

        Returns:
            Dictionary of calculated features
        """
        try:
            # Use the shared feature calculator for consistency
            from shared.ml.feature_calculator import HistoricalFeatureCalculator

            calculator = HistoricalFeatureCalculator()

            # Calculate all features using the shared calculator
            features = calculator.calculate_all_features(
                sales_data=historical_sales,
                reference_date=forecast_date,
                mode='prediction'
            )

            logger.debug("Historical features calculated (using shared calculator)",
                         lag_1_day=features.get('lag_1_day', 0.0),
                         rolling_mean_7d=features.get('rolling_mean_7d', 0.0),
                         rolling_mean_30d=features.get('rolling_mean_30d', 0.0),
                         momentum=features.get('momentum_1_7', 0.0),
                         days_since_last_sale=features.get('days_since_last_sale', 0),
                         data_availability_score=features.get('historical_data_availability_score', 0.0))

            return features

        except Exception as e:
            logger.error("Error calculating historical features",
                         error=str(e))
            # Return default values on error
            return {k: 0.0 for k in [
                'lag_1_day', 'lag_7_day', 'lag_14_day',
                'rolling_mean_7d', 'rolling_std_7d', 'rolling_max_7d', 'rolling_min_7d',
                'rolling_mean_14d', 'rolling_std_14d', 'rolling_max_14d', 'rolling_min_14d',
                'rolling_mean_30d', 'rolling_std_30d', 'rolling_max_30d', 'rolling_min_30d',
                'momentum_1_7', 'trend_7_30', 'velocity_week',
                'days_since_last_sale', 'historical_data_availability_score'
            ]}
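
    # Rough sketch of what the shared calculator computes; the actual logic
    # lives in shared.ml.feature_calculator and is the source of truth:
    #
    #     lag_1_day       ~ sales on (forecast_date - 1 day), else 0.0
    #     rolling_mean_7d ~ mean of sales over the 7 days before forecast_date
    #     momentum_1_7    ~ relation of lag_1_day to rolling_mean_7d (assumed)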

    def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
        """Convert features to a Prophet-compatible DataFrame - COMPLETE FEATURE MATCHING"""
        try:
            # Create the base DataFrame with the required 'ds' column
            df = pd.DataFrame({
                'ds': [pd.to_datetime(features['date'])]
            })

            # ✅ FIX: Add ALL traffic features that the training service uses
            # Core traffic features
            df['traffic_volume'] = float(features.get('traffic_volume', 100.0))
            df['pedestrian_count'] = float(features.get('pedestrian_count', 50.0))
            df['congestion_level'] = float(features.get('congestion_level', 1.0))
            df['average_speed'] = float(features.get('average_speed', 30.0))  # Previously missing feature

            # Weather features
            df['temperature'] = float(features.get('temperature', 15.0))
            df['precipitation'] = float(features.get('precipitation', 0.0))
            df['humidity'] = float(features.get('humidity', 60.0))
            df['wind_speed'] = float(features.get('wind_speed', 5.0))
            df['pressure'] = float(features.get('pressure', 1013.0))
            df['temp_category'] = self._get_temp_category(df['temperature'].iloc[0])

            # Extract date information for temporal features
            forecast_date = pd.to_datetime(features['date'])
            day_of_week = forecast_date.weekday()  # 0=Monday, 6=Sunday

            # ✅ FIX: Add ALL temporal features (must match training exactly!)
            df['day_of_week'] = int(day_of_week)
            df['day_of_month'] = int(forecast_date.day)
            df['month'] = int(forecast_date.month)
            df['quarter'] = int(forecast_date.quarter)
            df['week_of_year'] = int(forecast_date.isocalendar().week)

            # ✅ FIX: Add the missing 'season' feature that matches training exactly
            df['season'] = self._get_season(forecast_date.month)

            # Bakery-specific temporal features
            df['is_weekend'] = int(day_of_week >= 5)
            df['is_monday'] = int(day_of_week == 0)
            df['is_tuesday'] = int(day_of_week == 1)
            df['is_wednesday'] = int(day_of_week == 2)
            df['is_thursday'] = int(day_of_week == 3)
            df['is_friday'] = int(day_of_week == 4)
            df['is_saturday'] = int(day_of_week == 5)
            df['is_sunday'] = int(day_of_week == 6)
            df['is_working_day'] = int(day_of_week < 5)  # Working days (Mon-Fri)

            # Season-based features (match the training service)
            df['is_spring'] = int(df['season'].iloc[0] == 2)
            df['is_summer'] = int(df['season'].iloc[0] == 3)
            df['is_autumn'] = int(df['season'].iloc[0] == 4)
            df['is_winter'] = int(df['season'].iloc[0] == 1)

            # ✅ PERFORMANCE FIX: Build all remaining features at once to avoid DataFrame fragmentation

            # Extract values once to avoid repeated iloc calls
            temperature = df['temperature'].iloc[0]
            humidity = df['humidity'].iloc[0]
            pressure = df['pressure'].iloc[0]
            wind_speed = df['wind_speed'].iloc[0]
            precipitation = df['precipitation'].iloc[0]
            traffic = df['traffic_volume'].iloc[0]
            pedestrians = df['pedestrian_count'].iloc[0]
            avg_speed = df['average_speed'].iloc[0]
            congestion = df['congestion_level'].iloc[0]
            season = df['season'].iloc[0]
            is_weekend = df['is_weekend'].iloc[0]

            # Build all new features as a dictionary
            new_features = {
                # Holiday features
                'is_holiday': int(features.get('is_holiday', False)),
                'is_school_holiday': int(features.get('is_school_holiday', False)),

                # Month-based features
                'is_january': int(forecast_date.month == 1),
                'is_february': int(forecast_date.month == 2),
                'is_march': int(forecast_date.month == 3),
                'is_april': int(forecast_date.month == 4),
                'is_may': int(forecast_date.month == 5),
                'is_june': int(forecast_date.month == 6),
                'is_july': int(forecast_date.month == 7),
                'is_august': int(forecast_date.month == 8),
                'is_september': int(forecast_date.month == 9),
                'is_october': int(forecast_date.month == 10),
                'is_november': int(forecast_date.month == 11),
                'is_december': int(forecast_date.month == 12),

                # Special day features
                'is_month_start': int(forecast_date.day <= 3),
                'is_month_end': int(forecast_date.day >= 28),
                'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
                # CRITICAL FIX: Add the is_payday feature to match the training service.
                # Training defines: is_payday = (day == 15 OR day == 28 OR is_month_end).
                # Spain commonly pays on the 28th, the 15th, or the last day of the month.
                'is_payday': int((forecast_date.day == 15) or (forecast_date.day == 28) or self._is_end_of_month(forecast_date)),

                # Weather-based derived features
                'temp_squared': temperature ** 2,
                'is_cold_day': int(temperature < 10),
                'is_hot_day': int(temperature > 25),
                'is_pleasant_day': int(10 <= temperature <= 25),

                # Humidity features
                'humidity_squared': humidity ** 2,
                'is_high_humidity': int(humidity > 70),
                'is_low_humidity': int(humidity < 40),

                # Pressure features
                'pressure_squared': pressure ** 2,
                'is_high_pressure': int(pressure > 1020),
                'is_low_pressure': int(pressure < 1000),

                # Wind features
                'wind_squared': wind_speed ** 2,
                'is_windy': int(wind_speed > 15),
                'is_calm': int(wind_speed < 5),

                # Precipitation features
                'precip_squared': precipitation ** 2,
                'precip_log': float(np.log1p(precipitation)),
                'is_rainy_day': int(precipitation > 0.1),
                'is_very_rainy_day': int(precipitation > 5.0),
                'is_heavy_rain': int(precipitation > 10),
                'rain_intensity': self._get_rain_intensity(precipitation),

                # Traffic-based features
                'high_traffic': int(traffic > 150) if traffic > 0 else 0,
                'low_traffic': int(traffic < 50) if traffic > 0 else 0,
                # Fix: Use the same normalization as training (when std=0, normalized=0.0).
                # Training uses constant 100.0 values, so std=0 and normalized=0.0.
                'traffic_normalized': 0.0,  # Match training behavior for consistent predictions
                'traffic_squared': traffic ** 2,
                'traffic_log': float(np.log1p(traffic)),

                # Pedestrian features
                'high_pedestrian_count': int(pedestrians > 100),
                'low_pedestrian_count': int(pedestrians < 25),
                'pedestrian_normalized': float((pedestrians - 50) / 25),
                'pedestrian_squared': pedestrians ** 2,
                'pedestrian_log': float(np.log1p(pedestrians)),

                # Speed features
                'high_speed': int(avg_speed > 40),
                'low_speed': int(avg_speed < 20),
                'speed_normalized': float((avg_speed - 30) / 10),
                'speed_squared': avg_speed ** 2,
                'speed_log': float(np.log1p(avg_speed)),

                # Congestion features
                'high_congestion': int(congestion > 3),
                'low_congestion': int(congestion < 2),
                'congestion_squared': congestion ** 2,

                # Day features
                'is_peak_bakery_day': int(day_of_week in [4, 5, 6]),
                'is_high_demand_month': int(forecast_date.month in [6, 7, 8, 12]),
                'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9]),

                # CRITICAL FIX: Cyclical encoding features (MATCH TRAINING).
                # These encode day_of_week and month as sin/cos for cyclical patterns.
                'day_of_week_sin': float(np.sin(2 * np.pi * day_of_week / 7)),
                'day_of_week_cos': float(np.cos(2 * np.pi * day_of_week / 7)),
                'month_sin': float(np.sin(2 * np.pi * forecast_date.month / 12)),
                'month_cos': float(np.cos(2 * np.pi * forecast_date.month / 12)),

                # CRITICAL FIX: Historical features (lagged, rolling, trend).
                # These are populated from historical sales data when available;
                # they default to 0.0 otherwise.
                'lag_1_day': float(features.get('lag_1_day', 0.0)),
                'lag_7_day': float(features.get('lag_7_day', 0.0)),
                'lag_14_day': float(features.get('lag_14_day', 0.0)),
                'rolling_mean_7d': float(features.get('rolling_mean_7d', 0.0)),
                'rolling_std_7d': float(features.get('rolling_std_7d', 0.0)),
                'rolling_max_7d': float(features.get('rolling_max_7d', 0.0)),
                'rolling_min_7d': float(features.get('rolling_min_7d', 0.0)),
                'rolling_mean_14d': float(features.get('rolling_mean_14d', 0.0)),
                'rolling_std_14d': float(features.get('rolling_std_14d', 0.0)),
                'rolling_max_14d': float(features.get('rolling_max_14d', 0.0)),
                'rolling_min_14d': float(features.get('rolling_min_14d', 0.0)),
                'rolling_mean_30d': float(features.get('rolling_mean_30d', 0.0)),
                'rolling_std_30d': float(features.get('rolling_std_30d', 0.0)),
                'rolling_max_30d': float(features.get('rolling_max_30d', 0.0)),
                'rolling_min_30d': float(features.get('rolling_min_30d', 0.0)),
                'days_since_start': int(features.get('days_since_start', 0)),
                'momentum_1_7': float(features.get('momentum_1_7', 0.0)),
                'trend_7_30': float(features.get('trend_7_30', 0.0)),
                'velocity_week': float(features.get('velocity_week', 0.0)),
                # Data freshness metrics to help the model understand data recency
                'days_since_last_sale': int(features.get('days_since_last_sale', 0)),
                'historical_data_availability_score': float(features.get('historical_data_availability_score', 0.0)),
            }

            # Calculate interaction features
            is_holiday = new_features['is_holiday']
            is_pleasant = new_features['is_pleasant_day']
            is_rainy = new_features['is_rainy_day']
            is_payday = new_features['is_payday']

            interaction_features = {
                # Weekend interactions
                'weekend_temp_interaction': is_weekend * temperature,
                'weekend_pleasant_weather': is_weekend * is_pleasant,
                'weekend_traffic_interaction': is_weekend * traffic,

                # Holiday interactions
                'holiday_temp_interaction': is_holiday * temperature,
                'holiday_traffic_interaction': is_holiday * traffic,

                # CRITICAL FIX: Add payday_weekend_interaction to match the training service
                'payday_weekend_interaction': is_payday * is_weekend,

                # Season interactions
                'season_temp_interaction': season * temperature,
                'season_traffic_interaction': season * traffic,

                # Rain-traffic interactions
                'rain_traffic_interaction': is_rainy * traffic,
                'rain_speed_interaction': is_rainy * avg_speed,

                # CRITICAL FIX: Add missing interaction features from training
                'rain_weekend_interaction': is_rainy * is_weekend,
                'friday_traffic_interaction': int(day_of_week == 4) * traffic,

                # Day-weather interactions
                'day_temp_interaction': day_of_week * temperature,
                'month_temp_interaction': forecast_date.month * temperature,

                # Traffic-speed interactions
                'traffic_speed_interaction': traffic * avg_speed,
                'pedestrian_speed_interaction': pedestrians * avg_speed,

                # Congestion interactions
                'congestion_temp_interaction': congestion * temperature,
                'congestion_weekend_interaction': congestion * is_weekend
            }

            # CRITICAL FIX: Extract POI (Point of Interest) features from the features dict.
            # POI features start with the 'poi_' prefix and must be included for models trained with them.
            # This prevents "Regressor 'poi_retail_total_count' missing" errors.
            poi_features = {}
            for key, value in features.items():
                if key.startswith('poi_'):
                    # Ensure POI features are numeric (float or int)
                    try:
                        poi_features[key] = float(value) if isinstance(value, (int, float, str)) else 0.0
                    except (ValueError, TypeError):
                        poi_features[key] = 0.0

            # Combine all features
            all_new_features = {**new_features, **interaction_features, **poi_features}

            # Add all features at once using pd.concat to avoid fragmentation
            new_feature_df = pd.DataFrame([all_new_features])
            df = pd.concat([df, new_feature_df], axis=1)

            logger.debug("Complete Prophet features prepared",
                         feature_count=len(df.columns),
                         date=features['date'],
                         season=df['season'].iloc[0],
                         traffic_volume=df['traffic_volume'].iloc[0],
                         average_speed=df['average_speed'].iloc[0],
                         pedestrian_count=df['pedestrian_count'].iloc[0],
                         poi_feature_count=len(poi_features))

            return df

        except Exception as e:
            logger.error("Error preparing Prophet features", error=str(e))
            raise
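
    # Why the prediction columns must match training: Prophet only accepts the
    # extra regressors registered at fit time. A minimal sketch (not the actual
    # training code):
    #
    #     from prophet import Prophet
    #     m = Prophet()
    #     m.add_regressor('traffic_volume')
    #     m.add_regressor('season')
    #     m.fit(train_df)       # train_df must contain those columns plus ds/y
    #     m.predict(future_df)  # future_df must contain them too, or it raises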

    def _get_season(self, month: int) -> int:
        """Get the season from the month (1-4 for Winter, Spring, Summer, Autumn) - MATCH TRAINING"""
        if month in [12, 1, 2]:
            return 1  # Winter
        elif month in [3, 4, 5]:
            return 2  # Spring
        elif month in [6, 7, 8]:
            return 3  # Summer
        else:
            return 4  # Autumn

    def _is_school_holiday(self, date: datetime) -> bool:
        """Check whether a date falls during school holidays - MATCH TRAINING"""
        month = date.month

        # Approximate Spanish school holiday periods
        if month in [7, 8]:  # Summer holidays
            return True
        if month == 12 and date.day >= 20:  # Christmas holidays
            return True
        if month == 1 and date.day <= 10:  # Christmas holidays continued
            return True
        if month == 4 and date.day <= 15:  # Easter holidays (approximate)
            return True

        return False

    def _get_temp_category(self, temperature: float) -> int:
        """Get the temperature category (0-3) - MATCH TRAINING"""
        if temperature <= 5:
            return 0  # Very cold
        elif temperature <= 15:
            return 1  # Cold
        elif temperature <= 25:
            return 2  # Mild
        else:
            return 3  # Hot

    def _get_rain_intensity(self, precipitation: float) -> int:
        """Get the rain intensity category (0-3) - MATCH TRAINING"""
        if precipitation <= 0:
            return 0  # No rain
        elif precipitation <= 2:
            return 1  # Light rain
        elif precipitation <= 10:
            return 2  # Moderate rain
        else:
            return 3  # Heavy rain

    def _is_end_of_month(self, date: datetime) -> bool:
        """
        Check whether the date is the last day of the month - MATCH TRAINING SERVICE
        Training uses: df[date_column].dt.is_month_end
        """
        import calendar
        # Get the last day of the month
        last_day = calendar.monthrange(date.year, date.month)[1]
        return date.day == last_day
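

# Minimal smoke test of the pure date/weather helpers. This runs only where the
# module's imports (app.*, shared.*) resolve, i.e. inside the service
# environment; __new__ skips the external wiring done in __init__.
if __name__ == "__main__":
    svc = PredictionService.__new__(PredictionService)
    assert svc._get_season(1) == 1 and svc._get_season(7) == 3  # Winter, Summer
    assert svc._get_temp_category(30.0) == 3                    # Hot
    assert svc._get_rain_intensity(5.0) == 2                    # Moderate rain
    assert svc._is_end_of_month(datetime(2024, 2, 29))          # Leap-year month end
    print("prediction_service helper smoke tests passed")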