Fix orchestrator issues
@@ -287,7 +287,7 @@ async def generate_batch_forecast(
     from app.schemas.forecasts import BatchForecastResponse
     now = datetime.now(timezone.utc)
     return BatchForecastResponse(
-        id=batch_result.get('batch_id', str(uuid.uuid4())),
+        id=batch_result.get('id', str(uuid.uuid4())),  # Use 'id' field (UUID) instead of 'batch_id' (string)
         tenant_id=tenant_id,
         batch_name=updated_request.batch_name,
         status="completed",
@@ -5,6 +5,7 @@ Main forecasting service that uses the repository pattern for data access

 import structlog
 import uuid
+import asyncio
 from typing import Dict, List, Any, Optional
 from datetime import datetime, date, timedelta, timezone
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -63,8 +64,10 @@ class EnhancedForecastingService:
         """Generate batch forecasts using repository pattern"""
         try:
             # Implementation would use repository pattern to generate multiple forecasts
+            batch_uuid = uuid.uuid4()
             return {
-                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                "id": str(batch_uuid),  # UUID for database references
+                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",  # Human-readable batch identifier
                 "tenant_id": tenant_id,
                 "forecasts": [],
                 "total_forecasts": 0,
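Note: the hunk above deliberately keeps two identifiers with different jobs. A minimal sketch of the pattern, assuming only what the diff shows (the field names come from the hunk; everything else is illustrative):

    import uuid
    from datetime import datetime

    def make_batch_identifiers(tenant_id: str) -> dict:
        """Return a stable UUID for database references plus a
        human-readable batch_id for logs and UI."""
        return {
            "id": str(uuid.uuid4()),  # safe to use as a database key
            "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        }

    ids = make_batch_identifiers("tenant-42")
    uuid.UUID(ids["id"])  # parses as a UUID; ids["batch_id"] is display-only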
@@ -368,7 +371,7 @@ class EnhancedForecastingService:

                 forecast = await repos['forecast'].create_forecast(forecast_data)

-                # Step 7: Cache the prediction
+                # Step 6: Cache the prediction
                 await repos['cache'].cache_prediction(
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id,
@@ -521,51 +524,62 @@ class EnhancedForecastingService:
         Generate forecast using a pre-fetched weather map to avoid multiple API calls.
         """
         start_time = datetime.now(timezone.utc)

         try:
             logger.info("Generating enhanced forecast with weather map",
                         tenant_id=tenant_id,
                         inventory_product_id=request.inventory_product_id,
                         date=request.forecast_date.isoformat())

-            # Get session and initialize repositories
+            # CRITICAL FIX: Get model BEFORE opening database session to prevent session blocking during HTTP calls
+            # This prevents holding database connections during potentially slow external API calls
+            logger.debug("Fetching model data before opening database session",
+                         tenant_id=tenant_id,
+                         inventory_product_id=request.inventory_product_id)
+
+            model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
+
+            if not model_data:
+                raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
+
+            logger.debug("Model data fetched successfully",
+                         tenant_id=tenant_id,
+                         model_id=model_data.get('model_id'))
+
+            # Prepare features (this doesn't make external HTTP calls when using weather_map)
+            features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
+
+            # Now open database session AFTER external HTTP calls are complete
             async with self.database_manager.get_background_session() as session:
                 repos = await self._init_repositories(session)

                 # Step 1: Check cache first
                 cached_prediction = await repos['cache'].get_cached_prediction(
                     tenant_id, request.inventory_product_id, request.location, request.forecast_date
                 )

                 if cached_prediction:
                     logger.debug("Using cached prediction",
                                  tenant_id=tenant_id,
                                  inventory_product_id=request.inventory_product_id)
                     return self._create_forecast_response_from_cache(cached_prediction)

-                # Step 2: Get model with validation
-                model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
-
-                if not model_data:
-                    raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
-
-                # Step 3: Prepare features with fallbacks, using the weather map
-                features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
-
-                # Step 4: Generate prediction
+                # Step 2: Model data already fetched above (before session opened)
+
+                # Step 3: Generate prediction
                 prediction_result = await self.prediction_service.predict(
                     model_id=model_data['model_id'],
                     model_path=model_data['model_path'],
                     features=features,
                     confidence_level=request.confidence_level
                 )

-                # Step 5: Apply business rules
+                # Step 4: Apply business rules
                 adjusted_prediction = self._apply_business_rules(
                     prediction_result, request, features
                 )

-                # Step 6: Save forecast using repository
+                # Step 5: Save forecast using repository
                 # Convert forecast_date to datetime if it's a string
                 forecast_datetime = request.forecast_date
                 if isinstance(forecast_datetime, str):
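Note: the reordering in this hunk follows a general pattern worth naming: finish slow external I/O before taking a pooled database connection, so the pool is never held hostage by a slow HTTP call. A self-contained sketch of the idea (session_scope and fetch_model are stand-ins, not names from this codebase):

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def session_scope():
        # stand-in for database_manager.get_background_session()
        print("session opened")
        try:
            yield "session"
        finally:
            print("session closed")

    async def fetch_model() -> dict:
        await asyncio.sleep(0.1)  # stand-in for a slow external HTTP call
        return {"model_id": "m1"}

    async def forecast():
        model = await fetch_model()       # slow I/O first, no connection held
        async with session_scope() as s:  # connection held only for DB work
            print("saving forecast for", model["model_id"], "in", s)

    asyncio.run(forecast())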
@@ -599,7 +613,7 @@ class EnhancedForecastingService:

                 forecast = await repos['forecast'].create_forecast(forecast_data)

-                # Step 7: Cache the prediction
+                # Step 6: Cache the prediction
                 await repos['cache'].cache_prediction(
                     tenant_id=tenant_id,
                     inventory_product_id=request.inventory_product_id,
@@ -813,32 +827,51 @@ class EnhancedForecastingService:
     # Additional helper methods from original service
     async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
-        """Get the latest trained model with fallback strategies"""
+        """
+        Get the latest trained model with fallback strategies.
+
+        CRITICAL FIX: Added timeout protection to prevent hanging during external API calls.
+        This ensures we don't block indefinitely if the training service is unresponsive.
+        """
         try:
-            model_data = await self.model_client.get_best_model_for_forecasting(
-                tenant_id=tenant_id,
-                inventory_product_id=inventory_product_id
-            )
+            # Add timeout protection (15 seconds) to prevent hanging
+            # This is shorter than the default 30s to fail fast and avoid blocking
+            model_data = await asyncio.wait_for(
+                self.model_client.get_best_model_for_forecasting(
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id
+                ),
+                timeout=15.0
+            )

             if model_data:
                 logger.info("Found specific model for product",
                             inventory_product_id=inventory_product_id,
                             model_id=model_data.get('model_id'))
                 return model_data

-            # Fallback: Try to get any model for this tenant
-            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)
+            # Fallback: Try to get any model for this tenant (also with timeout)
+            fallback_model = await asyncio.wait_for(
+                self.model_client.get_any_model_for_tenant(tenant_id),
+                timeout=15.0
+            )

             if fallback_model:
                 logger.info("Using fallback model",
                             model_id=fallback_model.get('model_id'))
                 return fallback_model

             logger.error("No models available for tenant", tenant_id=tenant_id)
             return None

+        except asyncio.TimeoutError:
+            logger.error("Timeout fetching model data from training service",
+                         tenant_id=tenant_id,
+                         inventory_product_id=inventory_product_id,
+                         timeout_seconds=15)
+            return None
         except Exception as e:
-            logger.error("Error getting model", error=str(e))
+            logger.error("Error getting model", error=str(e), tenant_id=tenant_id)
             return None

     async def _prepare_forecast_features_with_fallbacks(
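Note: asyncio.wait_for cancels the awaited coroutine and raises asyncio.TimeoutError once the deadline passes, which is exactly what the new except branch handles. A runnable miniature of the same fail-fast shape (timeout shortened so the example finishes quickly):

    import asyncio

    async def slow_model_lookup() -> dict:
        await asyncio.sleep(60)  # simulates an unresponsive training service
        return {"model_id": "m1"}

    async def get_model_or_none():
        try:
            # wait_for cancels slow_model_lookup() when the timeout elapses
            return await asyncio.wait_for(slow_model_lookup(), timeout=0.1)
        except asyncio.TimeoutError:
            return None  # fail fast instead of blocking the caller

    print(asyncio.run(get_model_or_none()))  # -> None after ~0.1 s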
@@ -857,6 +890,9 @@ class EnhancedForecastingService:
             "week_of_year": request.forecast_date.isocalendar().week,
             "season": self._get_season(request.forecast_date.month),
             "is_holiday": self._is_spanish_holiday(request.forecast_date),
+            # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
+            "tenant_id": tenant_id,
+            "inventory_product_id": request.inventory_product_id,
         }

         # Fetch REAL weather data from external service
@@ -951,6 +987,9 @@ class EnhancedForecastingService:
             "week_of_year": request.forecast_date.isocalendar().week,
             "season": self._get_season(request.forecast_date.month),
             "is_holiday": self._is_spanish_holiday(request.forecast_date),
+            # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
+            "tenant_id": tenant_id,
+            "inventory_product_id": request.inventory_product_id,
         }

         # Use the pre-fetched weather data from the weather map to avoid additional API calls
@@ -20,6 +20,7 @@ import joblib
 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
 from shared.database.base import create_database_manager
+from shared.clients import get_sales_client

 logger = structlog.get_logger()
 metrics = MetricsCollector("forecasting-service")
@@ -34,6 +35,8 @@ class PredictionService:
         self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
         self.model_cache = {}
         self.cache_ttl = 3600  # 1 hour cache
+        # Initialize sales client for fetching historical data
+        self.sales_client = get_sales_client(settings, "forecasting")

     async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
         """Validate prediction request"""
@@ -79,7 +82,34 @@ class PredictionService:

         if not model:
             raise ValueError(f"Model {model_id} not found or failed to load")

+        # CRITICAL FIX: Fetch historical sales data and calculate historical features
+        # This populates lag, rolling, and trend features for better predictions
+        # Using 90 days for better trend analysis and more robust rolling statistics
+        if 'tenant_id' in features and 'inventory_product_id' in features and 'date' in features:
+            try:
+                forecast_date = pd.to_datetime(features['date'])
+                historical_sales = await self._fetch_historical_sales(
+                    tenant_id=features['tenant_id'],
+                    inventory_product_id=features['inventory_product_id'],
+                    forecast_date=forecast_date,
+                    days_back=90  # Changed from 30 to 90 for better historical context
+                )
+
+                # Calculate historical features and merge into features dict
+                historical_features = self._calculate_historical_features(
+                    historical_sales, forecast_date
+                )
+                features.update(historical_features)
+
+                logger.info("Historical features enriched",
+                            lag_1_day=historical_features.get('lag_1_day'),
+                            rolling_mean_7d=historical_features.get('rolling_mean_7d'))
+            except Exception as e:
+                logger.warning("Failed to enrich with historical features, using defaults",
+                               error=str(e))
+                # Features dict will use defaults (0.0) from _prepare_prophet_features
+
         # Prepare features for Prophet model
         prophet_df = self._prepare_prophet_features(features)
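Note: the enrichment is safe to fail because of merge order: features.update(historical_features) overwrites the 0.0 defaults only when the fetch succeeded; on any exception the defaults survive. A toy illustration of that contract (names are illustrative):

    defaults = {"lag_1_day": 0.0, "rolling_mean_7d": 0.0}

    def enrich(features: dict, historical: dict | None) -> dict:
        merged = {**defaults, **features}
        if historical:  # on fetch failure, historical is None and defaults remain
            merged.update(historical)
        return merged

    print(enrich({"temperature": 21.0}, {"lag_1_day": 13.0}))
    # {'lag_1_day': 13.0, 'rolling_mean_7d': 0.0, 'temperature': 21.0}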
@@ -444,7 +474,222 @@ class PredictionService:
         except Exception as e:
             logger.error(f"Model validation error: {e}")
             return False

+    async def _fetch_historical_sales(
+        self,
+        tenant_id: str,
+        inventory_product_id: str,
+        forecast_date: datetime,
+        days_back: int = 90
+    ) -> pd.Series:
+        """
+        Fetch historical sales data for calculating lagged and rolling features.
+
+        Args:
+            tenant_id: Tenant UUID
+            inventory_product_id: Product UUID
+            forecast_date: The date we're forecasting for
+            days_back: Number of days of history to fetch (default 90 for better trend analysis)
+
+        Returns:
+            pandas Series with sales quantities indexed by date
+        """
+        try:
+            # Calculate date range
+            end_date = forecast_date - pd.Timedelta(days=1)  # Day before forecast
+            start_date = end_date - pd.Timedelta(days=days_back)
+
+            logger.debug("Fetching historical sales for feature calculation",
+                         tenant_id=tenant_id,
+                         product_id=inventory_product_id,
+                         start_date=start_date.date(),
+                         end_date=end_date.date(),
+                         days_back=days_back)
+
+            # Fetch sales data from sales service
+            sales_data = await self.sales_client.get_sales_data(
+                tenant_id=tenant_id,
+                start_date=start_date.strftime("%Y-%m-%d"),
+                end_date=end_date.strftime("%Y-%m-%d"),
+                product_id=inventory_product_id,
+                aggregation="daily"
+            )
+
+            if not sales_data:
+                logger.warning("No historical sales data found",
+                               tenant_id=tenant_id,
+                               product_id=inventory_product_id)
+                return pd.Series(dtype=float)
+
+            # Convert to pandas Series indexed by date
+            df = pd.DataFrame(sales_data)
+            df['sale_date'] = pd.to_datetime(df['sale_date'])
+            df = df.set_index('sale_date')
+
+            # Extract quantity column (could be 'quantity' or 'total_quantity')
+            if 'quantity' in df.columns:
+                series = df['quantity']
+            elif 'total_quantity' in df.columns:
+                series = df['total_quantity']
+            else:
+                logger.warning("Sales data missing quantity field",
+                               columns=list(df.columns))
+                return pd.Series(dtype=float)
+
+            logger.debug("Historical sales fetched successfully",
+                         records=len(series),
+                         date_range=f"{series.index.min()} to {series.index.max()}")
+
+            return series.sort_index()
+
+        except Exception as e:
+            logger.error("Error fetching historical sales",
+                         error=str(e),
+                         tenant_id=tenant_id,
+                         product_id=inventory_product_id)
+            return pd.Series(dtype=float)
+
+    def _calculate_historical_features(
+        self,
+        historical_sales: pd.Series,
+        forecast_date: datetime
+    ) -> Dict[str, float]:
+        """
+        Calculate lagged, rolling, and trend features from historical sales data.
+
+        Args:
+            historical_sales: Series of sales quantities indexed by date
+            forecast_date: The date we're forecasting for
+
+        Returns:
+            Dictionary of calculated features
+        """
+        features = {}
+
+        try:
+            if len(historical_sales) == 0:
+                logger.warning("No historical data available, using default values")
+                # Return all features with default values (0.0)
+                return {
+                    # Lagged features
+                    'lag_1_day': 0.0,
+                    'lag_7_day': 0.0,
+                    'lag_14_day': 0.0,
+                    # Rolling statistics (7-day window)
+                    'rolling_mean_7d': 0.0,
+                    'rolling_std_7d': 0.0,
+                    'rolling_max_7d': 0.0,
+                    'rolling_min_7d': 0.0,
+                    # Rolling statistics (14-day window)
+                    'rolling_mean_14d': 0.0,
+                    'rolling_std_14d': 0.0,
+                    'rolling_max_14d': 0.0,
+                    'rolling_min_14d': 0.0,
+                    # Rolling statistics (30-day window)
+                    'rolling_mean_30d': 0.0,
+                    'rolling_std_30d': 0.0,
+                    'rolling_max_30d': 0.0,
+                    'rolling_min_30d': 0.0,
+                    # Trend features
+                    'days_since_start': 0,
+                    'momentum_1_7': 0.0,
+                    'trend_7_30': 0.0,
+                    'velocity_week': 0.0,
+                }
+
+            # Calculate lagged features
+            features['lag_1_day'] = float(historical_sales.iloc[-1]) if len(historical_sales) >= 1 else 0.0
+            features['lag_7_day'] = float(historical_sales.iloc[-7]) if len(historical_sales) >= 7 else features['lag_1_day']
+            features['lag_14_day'] = float(historical_sales.iloc[-14]) if len(historical_sales) >= 14 else features['lag_7_day']
+
+            # Calculate rolling statistics (7-day window)
+            if len(historical_sales) >= 7:
+                window_7d = historical_sales.iloc[-7:]
+                features['rolling_mean_7d'] = float(window_7d.mean())
+                features['rolling_std_7d'] = float(window_7d.std())
+                features['rolling_max_7d'] = float(window_7d.max())
+                features['rolling_min_7d'] = float(window_7d.min())
+            else:
+                features['rolling_mean_7d'] = features['lag_1_day']
+                features['rolling_std_7d'] = 0.0
+                features['rolling_max_7d'] = features['lag_1_day']
+                features['rolling_min_7d'] = features['lag_1_day']
+
+            # Calculate rolling statistics (14-day window)
+            if len(historical_sales) >= 14:
+                window_14d = historical_sales.iloc[-14:]
+                features['rolling_mean_14d'] = float(window_14d.mean())
+                features['rolling_std_14d'] = float(window_14d.std())
+                features['rolling_max_14d'] = float(window_14d.max())
+                features['rolling_min_14d'] = float(window_14d.min())
+            else:
+                features['rolling_mean_14d'] = features['rolling_mean_7d']
+                features['rolling_std_14d'] = features['rolling_std_7d']
+                features['rolling_max_14d'] = features['rolling_max_7d']
+                features['rolling_min_14d'] = features['rolling_min_7d']
+
+            # Calculate rolling statistics (30-day window)
+            if len(historical_sales) >= 30:
+                window_30d = historical_sales.iloc[-30:]
+                features['rolling_mean_30d'] = float(window_30d.mean())
+                features['rolling_std_30d'] = float(window_30d.std())
+                features['rolling_max_30d'] = float(window_30d.max())
+                features['rolling_min_30d'] = float(window_30d.min())
+            else:
+                features['rolling_mean_30d'] = features['rolling_mean_14d']
+                features['rolling_std_30d'] = features['rolling_std_14d']
+                features['rolling_max_30d'] = features['rolling_max_14d']
+                features['rolling_min_30d'] = features['rolling_min_14d']
+
+            # Calculate trend features
+            if len(historical_sales) > 0:
+                # Days since first sale
+                features['days_since_start'] = (forecast_date - historical_sales.index[0]).days
+
+                # Momentum (difference between recent lag_1_day and lag_7_day)
+                if len(historical_sales) >= 7:
+                    features['momentum_1_7'] = features['lag_1_day'] - features['lag_7_day']
+                else:
+                    features['momentum_1_7'] = 0.0
+
+                # Trend (difference between recent 7-day and 30-day averages)
+                if len(historical_sales) >= 30:
+                    features['trend_7_30'] = features['rolling_mean_7d'] - features['rolling_mean_30d']
+                else:
+                    features['trend_7_30'] = 0.0
+
+                # Velocity (rate of change over the last week)
+                if len(historical_sales) >= 7:
+                    week_change = historical_sales.iloc[-1] - historical_sales.iloc[-7]
+                    features['velocity_week'] = float(week_change / 7.0)
+                else:
+                    features['velocity_week'] = 0.0
+            else:
+                features['days_since_start'] = 0
+                features['momentum_1_7'] = 0.0
+                features['trend_7_30'] = 0.0
+                features['velocity_week'] = 0.0
+
+            logger.debug("Historical features calculated",
+                         lag_1_day=features['lag_1_day'],
+                         rolling_mean_7d=features['rolling_mean_7d'],
+                         rolling_mean_30d=features['rolling_mean_30d'],
+                         momentum=features['momentum_1_7'])
+
+            return features
+
+        except Exception as e:
+            logger.error("Error calculating historical features",
+                         error=str(e))
+            # Return default values on error
+            return {k: 0.0 for k in [
+                'lag_1_day', 'lag_7_day', 'lag_14_day',
+                'rolling_mean_7d', 'rolling_std_7d', 'rolling_max_7d', 'rolling_min_7d',
+                'rolling_mean_14d', 'rolling_std_14d', 'rolling_max_14d', 'rolling_min_14d',
+                'rolling_mean_30d', 'rolling_std_30d', 'rolling_max_30d', 'rolling_min_30d',
+                'momentum_1_7', 'trend_7_30', 'velocity_week'
+            ]} | {'days_since_start': 0}

     def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
         """Convert features to Prophet-compatible DataFrame - COMPLETE FEATURE MATCHING"""
||||
@@ -539,6 +784,9 @@ class PredictionService:
|
||||
'is_month_start': int(forecast_date.day <= 3),
|
||||
'is_month_end': int(forecast_date.day >= 28),
|
||||
'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
|
||||
# CRITICAL FIX: Add is_payday feature to match training service
|
||||
# Training defines: is_payday = (day == 15 OR is_month_end)
|
||||
'is_payday': int((forecast_date.day == 15) or self._is_end_of_month(forecast_date)),
|
||||
|
||||
# Weather-based derived features
|
||||
'temp_squared': temperature ** 2,
|
||||
@@ -600,23 +848,57 @@ class PredictionService:
             # Day features
             'is_peak_bakery_day': int(day_of_week in [4, 5, 6]),
             'is_high_demand_month': int(forecast_date.month in [6, 7, 8, 12]),
-            'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9])
+            'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9]),
+
+            # CRITICAL FIX: Cyclical encoding features (MATCH TRAINING)
+            # These encode day_of_week and month as sin/cos for cyclical patterns
+            'day_of_week_sin': float(np.sin(2 * np.pi * day_of_week / 7)),
+            'day_of_week_cos': float(np.cos(2 * np.pi * day_of_week / 7)),
+            'month_sin': float(np.sin(2 * np.pi * forecast_date.month / 12)),
+            'month_cos': float(np.cos(2 * np.pi * forecast_date.month / 12)),
+
+            # CRITICAL FIX: Historical features (lagged, rolling, trend)
+            # These will be populated from historical sales data
+            # Default to 0.0 here, will be updated if historical data is provided
+            'lag_1_day': float(features.get('lag_1_day', 0.0)),
+            'lag_7_day': float(features.get('lag_7_day', 0.0)),
+            'lag_14_day': float(features.get('lag_14_day', 0.0)),
+            'rolling_mean_7d': float(features.get('rolling_mean_7d', 0.0)),
+            'rolling_std_7d': float(features.get('rolling_std_7d', 0.0)),
+            'rolling_max_7d': float(features.get('rolling_max_7d', 0.0)),
+            'rolling_min_7d': float(features.get('rolling_min_7d', 0.0)),
+            'rolling_mean_14d': float(features.get('rolling_mean_14d', 0.0)),
+            'rolling_std_14d': float(features.get('rolling_std_14d', 0.0)),
+            'rolling_max_14d': float(features.get('rolling_max_14d', 0.0)),
+            'rolling_min_14d': float(features.get('rolling_min_14d', 0.0)),
+            'rolling_mean_30d': float(features.get('rolling_mean_30d', 0.0)),
+            'rolling_std_30d': float(features.get('rolling_std_30d', 0.0)),
+            'rolling_max_30d': float(features.get('rolling_max_30d', 0.0)),
+            'rolling_min_30d': float(features.get('rolling_min_30d', 0.0)),
+            'days_since_start': int(features.get('days_since_start', 0)),
+            'momentum_1_7': float(features.get('momentum_1_7', 0.0)),
+            'trend_7_30': float(features.get('trend_7_30', 0.0)),
+            'velocity_week': float(features.get('velocity_week', 0.0)),
         }

         # Calculate interaction features
         is_holiday = new_features['is_holiday']
         is_pleasant = new_features['is_pleasant_day']
         is_rainy = new_features['is_rainy_day']
+        is_payday = new_features['is_payday']

         interaction_features = {
             # Weekend interactions
             'weekend_temp_interaction': is_weekend * temperature,
             'weekend_pleasant_weather': is_weekend * is_pleasant,
             'weekend_traffic_interaction': is_weekend * traffic,

             # Holiday interactions
             'holiday_temp_interaction': is_holiday * temperature,
             'holiday_traffic_interaction': is_holiday * traffic,

+            # CRITICAL FIX: Add payday_weekend_interaction to match training service
+            'payday_weekend_interaction': is_payday * is_weekend,
+
             # Season interactions
             'season_temp_interaction': season * temperature,
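Note: the point of the sin/cos pair is that the encoding is continuous across the wrap-around, so Sunday (6) and Monday (0) land close together in feature space, where a raw integer would put them 6 apart. A quick check:

    import numpy as np

    def encode_day(d: int):
        return np.array([np.sin(2 * np.pi * d / 7), np.cos(2 * np.pi * d / 7)])

    monday, wednesday, sunday = encode_day(0), encode_day(3), encode_day(6)

    # Sunday is a near neighbour of Monday; Wednesday is far from both
    assert np.linalg.norm(sunday - monday) < np.linalg.norm(wednesday - monday)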
@@ -625,7 +907,11 @@ class PredictionService:
             # Rain-traffic interactions
             'rain_traffic_interaction': is_rainy * traffic,
             'rain_speed_interaction': is_rainy * avg_speed,

+            # CRITICAL FIX: Add missing interaction features from training
+            'rain_weekend_interaction': is_rainy * is_weekend,
+            'friday_traffic_interaction': int(day_of_week == 4) * traffic,
+
             # Day-weather interactions
             'day_temp_interaction': day_of_week * temperature,
             'month_temp_interaction': forecast_date.month * temperature,
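Note: most of the "CRITICAL FIX" lines in this file exist only to keep inference feature names in lockstep with training. A small guard like the following (illustrative, not part of this codebase) turns that implicit contract into an explicit failure:

    def assert_feature_parity(training_columns: set, inference_features: dict) -> None:
        """Fail loudly when inference is missing features the model was trained on."""
        missing = training_columns - inference_features.keys()
        extra = inference_features.keys() - training_columns
        if missing:
            raise ValueError(f"Inference is missing trained features: {sorted(missing)}")
        if extra:
            # Extra features are usually harmless but worth surfacing
            print(f"Inference computes unused features: {sorted(extra)}")

    assert_feature_parity(
        {"is_payday", "rain_weekend_interaction"},
        {"is_payday": 1, "rain_weekend_interaction": 0, "temp_squared": 441.0},
    )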
@@ -707,4 +993,14 @@ class PredictionService:
         elif precipitation <= 10:
             return 2  # Moderate rain
         else:
-            return 3  # Heavy rain
+            return 3  # Heavy rain
+
+    def _is_end_of_month(self, date: datetime) -> bool:
+        """
+        Check if date is the last day of the month - MATCH TRAINING SERVICE
+        Training uses: df[date_column].dt.is_month_end
+        """
+        import calendar
+        # Get the last day of the month
+        last_day = calendar.monthrange(date.year, date.month)[1]
+        return date.day == last_day
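Note: the docstring cites pandas' is_month_end as the training-side definition; the calendar-based check agrees with it, including on leap years, as this quick comparison (assuming pandas is available) shows:

    import calendar
    from datetime import datetime
    import pandas as pd

    def is_end_of_month(date: datetime) -> bool:
        last_day = calendar.monthrange(date.year, date.month)[1]
        return date.day == last_day

    for d in (datetime(2024, 2, 29), datetime(2024, 2, 28), datetime(2023, 12, 31)):
        assert is_end_of_month(d) == bool(pd.Timestamp(d).is_month_end)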