Fix orchestrator issues

This commit is contained in:
Urtzi Alfaro
2025-11-05 22:54:14 +01:00
parent 80728eaa4e
commit 3ad093d38b
9 changed files with 422 additions and 484 deletions

View File

@@ -287,7 +287,7 @@ async def generate_batch_forecast(
from app.schemas.forecasts import BatchForecastResponse
now = datetime.now(timezone.utc)
return BatchForecastResponse(
id=batch_result.get('batch_id', str(uuid.uuid4())),
id=batch_result.get('id', str(uuid.uuid4())), # Use 'id' field (UUID) instead of 'batch_id' (string)
tenant_id=tenant_id,
batch_name=updated_request.batch_name,
status="completed",

View File

@@ -5,6 +5,7 @@ Main forecasting service that uses the repository pattern for data access
import structlog
import uuid
import asyncio
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
@@ -63,8 +64,10 @@ class EnhancedForecastingService:
"""Generate batch forecasts using repository pattern"""
try:
# Implementation would use repository pattern to generate multiple forecasts
batch_uuid = uuid.uuid4()
return {
"batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
"id": str(batch_uuid), # UUID for database references
"batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", # Human-readable batch identifier
"tenant_id": tenant_id,
"forecasts": [],
"total_forecasts": 0,
@@ -368,7 +371,7 @@ class EnhancedForecastingService:
forecast = await repos['forecast'].create_forecast(forecast_data)
# Step 7: Cache the prediction
# Step 6: Cache the prediction
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
@@ -521,51 +524,62 @@ class EnhancedForecastingService:
Generate forecast using a pre-fetched weather map to avoid multiple API calls.
"""
start_time = datetime.now(timezone.utc)
try:
logger.info("Generating enhanced forecast with weather map",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
date=request.forecast_date.isoformat())
# Get session and initialize repositories
# CRITICAL FIX: Get model BEFORE opening database session to prevent session blocking during HTTP calls
# This prevents holding database connections during potentially slow external API calls
logger.debug("Fetching model data before opening database session",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
logger.debug("Model data fetched successfully",
tenant_id=tenant_id,
model_id=model_data.get('model_id'))
# Prepare features (this doesn't make external HTTP calls when using weather_map)
features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
# Now open database session AFTER external HTTP calls are complete
async with self.database_manager.get_background_session() as session:
repos = await self._init_repositories(session)
# Step 1: Check cache first
cached_prediction = await repos['cache'].get_cached_prediction(
tenant_id, request.inventory_product_id, request.location, request.forecast_date
)
if cached_prediction:
logger.debug("Using cached prediction",
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id)
return self._create_forecast_response_from_cache(cached_prediction)
# Step 2: Get model with validation
model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
if not model_data:
raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
# Step 3: Prepare features with fallbacks, using the weather map
features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
# Step 4: Generate prediction
# Step 2: Model data already fetched above (before session opened)
# Step 3: Generate prediction
prediction_result = await self.prediction_service.predict(
model_id=model_data['model_id'],
model_path=model_data['model_path'],
features=features,
confidence_level=request.confidence_level
)
# Step 5: Apply business rules
# Step 4: Apply business rules
adjusted_prediction = self._apply_business_rules(
prediction_result, request, features
)
# Step 6: Save forecast using repository
# Step 5: Save forecast using repository
# Convert forecast_date to datetime if it's a string
forecast_datetime = request.forecast_date
if isinstance(forecast_datetime, str):
@@ -599,7 +613,7 @@ class EnhancedForecastingService:
forecast = await repos['forecast'].create_forecast(forecast_data)
# Step 7: Cache the prediction
# Step 6: Cache the prediction
await repos['cache'].cache_prediction(
tenant_id=tenant_id,
inventory_product_id=request.inventory_product_id,
@@ -813,32 +827,51 @@ class EnhancedForecastingService:
# Additional helper methods from original service
async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
"""Get the latest trained model with fallback strategies"""
"""
Get the latest trained model with fallback strategies.
CRITICAL FIX: Added timeout protection to prevent hanging during external API calls.
This ensures we don't block indefinitely if the training service is unresponsive.
"""
try:
model_data = await self.model_client.get_best_model_for_forecasting(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
# Add timeout protection (15 seconds) to prevent hanging
# This is shorter than the default 30s to fail fast and avoid blocking
model_data = await asyncio.wait_for(
self.model_client.get_best_model_for_forecasting(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id
),
timeout=15.0
)
if model_data:
logger.info("Found specific model for product",
inventory_product_id=inventory_product_id,
model_id=model_data.get('model_id'))
return model_data
# Fallback: Try to get any model for this tenant
fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)
# Fallback: Try to get any model for this tenant (also with timeout)
fallback_model = await asyncio.wait_for(
self.model_client.get_any_model_for_tenant(tenant_id),
timeout=15.0
)
if fallback_model:
logger.info("Using fallback model",
model_id=fallback_model.get('model_id'))
return fallback_model
logger.error("No models available for tenant", tenant_id=tenant_id)
return None
except asyncio.TimeoutError:
logger.error("Timeout fetching model data from training service",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
timeout_seconds=15)
return None
except Exception as e:
logger.error("Error getting model", error=str(e))
logger.error("Error getting model", error=str(e), tenant_id=tenant_id)
return None
async def _prepare_forecast_features_with_fallbacks(
@@ -857,6 +890,9 @@ class EnhancedForecastingService:
"week_of_year": request.forecast_date.isocalendar().week,
"season": self._get_season(request.forecast_date.month),
"is_holiday": self._is_spanish_holiday(request.forecast_date),
# CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
}
# Fetch REAL weather data from external service
@@ -951,6 +987,9 @@ class EnhancedForecastingService:
"week_of_year": request.forecast_date.isocalendar().week,
"season": self._get_season(request.forecast_date.month),
"is_holiday": self._is_spanish_holiday(request.forecast_date),
# CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
"tenant_id": tenant_id,
"inventory_product_id": request.inventory_product_id,
}
# Use the pre-fetched weather data from the weather map to avoid additional API calls

View File

@@ -20,6 +20,7 @@ import joblib
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from shared.database.base import create_database_manager
from shared.clients import get_sales_client
logger = structlog.get_logger()
metrics = MetricsCollector("forecasting-service")
@@ -34,6 +35,8 @@ class PredictionService:
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
self.model_cache = {}
self.cache_ttl = 3600 # 1 hour cache
# Initialize sales client for fetching historical data
self.sales_client = get_sales_client(settings, "forecasting")
async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Validate prediction request"""
@@ -79,7 +82,34 @@ class PredictionService:
if not model:
raise ValueError(f"Model {model_id} not found or failed to load")
# CRITICAL FIX: Fetch historical sales data and calculate historical features
# This populates lag, rolling, and trend features for better predictions
# Using 90 days for better trend analysis and more robust rolling statistics
if 'tenant_id' in features and 'inventory_product_id' in features and 'date' in features:
try:
forecast_date = pd.to_datetime(features['date'])
historical_sales = await self._fetch_historical_sales(
tenant_id=features['tenant_id'],
inventory_product_id=features['inventory_product_id'],
forecast_date=forecast_date,
days_back=90 # Changed from 30 to 90 for better historical context
)
# Calculate historical features and merge into features dict
historical_features = self._calculate_historical_features(
historical_sales, forecast_date
)
features.update(historical_features)
logger.info("Historical features enriched",
lag_1_day=historical_features.get('lag_1_day'),
rolling_mean_7d=historical_features.get('rolling_mean_7d'))
except Exception as e:
logger.warning("Failed to enrich with historical features, using defaults",
error=str(e))
# Features dict will use defaults (0.0) from _prepare_prophet_features
# Prepare features for Prophet model
prophet_df = self._prepare_prophet_features(features)
@@ -444,7 +474,222 @@ class PredictionService:
except Exception as e:
logger.error(f"Model validation error: {e}")
return False
async def _fetch_historical_sales(
    self,
    tenant_id: str,
    inventory_product_id: str,
    forecast_date: datetime,
    days_back: int = 90
) -> pd.Series:
    """
    Load recent daily sales for one product as a date-indexed Series.

    The requested window ends the day before ``forecast_date`` and starts
    ``days_back`` days before that end date. The result feeds the lagged and
    rolling feature calculation.

    Args:
        tenant_id: Tenant identifier forwarded to the sales client.
        inventory_product_id: Product identifier forwarded to the sales client.
        forecast_date: The date being forecast; must support subtraction of a
            ``pd.Timedelta`` and expose ``date()`` / ``strftime()``.
        days_back: Length of the look-back window in days (default 90).

    Returns:
        Sales quantities indexed by ``sale_date`` and sorted by date. An empty
        float Series is returned when the client yields nothing, when neither
        a ``quantity`` nor a ``total_quantity`` column is present, or when any
        exception is raised along the way (best-effort: errors are logged,
        never propagated).
    """
    try:
        # Window: [forecast_date - 1 day - days_back, forecast_date - 1 day]
        window_end = forecast_date - pd.Timedelta(days=1)
        window_start = window_end - pd.Timedelta(days=days_back)

        logger.debug("Fetching historical sales for feature calculation",
                     tenant_id=tenant_id,
                     product_id=inventory_product_id,
                     start_date=window_start.date(),
                     end_date=window_end.date(),
                     days_back=days_back)

        # Ask the sales service for daily aggregates over the window
        sales_data = await self.sales_client.get_sales_data(
            tenant_id=tenant_id,
            start_date=window_start.strftime("%Y-%m-%d"),
            end_date=window_end.strftime("%Y-%m-%d"),
            product_id=inventory_product_id,
            aggregation="daily"
        )

        if not sales_data:
            logger.warning("No historical sales data found",
                           tenant_id=tenant_id,
                           product_id=inventory_product_id)
            return pd.Series(dtype=float)

        # Index the records by their parsed sale date
        frame = pd.DataFrame(sales_data)
        frame = frame.assign(
            sale_date=pd.to_datetime(frame['sale_date'])
        ).set_index('sale_date')

        # The payload may name the quantity 'quantity' or 'total_quantity';
        # the first one present (in that order) wins.
        quantity_column = next(
            (name for name in ('quantity', 'total_quantity') if name in frame.columns),
            None
        )
        if quantity_column is None:
            logger.warning("Sales data missing quantity field",
                           columns=list(frame.columns))
            return pd.Series(dtype=float)

        quantities = frame[quantity_column]

        logger.debug("Historical sales fetched successfully",
                     records=len(quantities),
                     date_range=f"{quantities.index.min()} to {quantities.index.max()}")

        return quantities.sort_index()

    except Exception as e:
        # Deliberate best-effort: callers fall back to default feature values
        logger.error("Error fetching historical sales",
                     error=str(e),
                     tenant_id=tenant_id,
                     product_id=inventory_product_id)
        return pd.Series(dtype=float)
def _calculate_historical_features(
    self,
    historical_sales: pd.Series,
    forecast_date: datetime
) -> Dict[str, float]:
    """
    Calculate lagged, rolling, and trend features from historical sales data.

    Lags and rolling windows are positional: ``lag_7_day`` is the 7th most
    recent observation and ``rolling_*_7d`` covers the last 7 observations.
    NOTE(review): this equals "7 days ago" only if the series holds exactly
    one row per consecutive day - confirm against the sales service payload.

    When the series is shorter than a window, values cascade from the next
    smaller one: 30d falls back to 14d, 14d to 7d, and 7d to ``lag_1_day``
    (with a standard deviation of 0.0).

    Args:
        historical_sales: Series of sales quantities indexed by date, oldest
            first.
        forecast_date: The date we're forecasting for.

    Returns:
        Dictionary with the lag, rolling (7/14/30), ``days_since_start``,
        ``momentum_1_7``, ``trend_7_30`` and ``velocity_week`` features.
        Every feature is zero when the series is empty or when an unexpected
        error occurs (the error is logged, not raised).
    """
    def default_features() -> Dict[str, float]:
        # Single source of truth for the "no usable history" result
        zeros: Dict[str, float] = {
            'lag_1_day': 0.0,
            'lag_7_day': 0.0,
            'lag_14_day': 0.0,
        }
        for days in (7, 14, 30):
            for stat in ('mean', 'std', 'max', 'min'):
                zeros[f'rolling_{stat}_{days}d'] = 0.0
        zeros.update({
            'days_since_start': 0,
            'momentum_1_7': 0.0,
            'trend_7_30': 0.0,
            'velocity_week': 0.0,
        })
        return zeros

    def without_timezone(moment):
        # Drop tz info while keeping the wall-clock time
        stamp = pd.Timestamp(moment)
        return stamp.tz_localize(None) if stamp.tzinfo is not None else stamp

    try:
        history_len = len(historical_sales)
        if history_len == 0:
            logger.warning("No historical data available, using default values")
            return default_features()

        features: Dict[str, float] = {}

        # Lagged features (series is non-empty here, so iloc[-1] always exists)
        features['lag_1_day'] = float(historical_sales.iloc[-1])
        features['lag_7_day'] = (
            float(historical_sales.iloc[-7]) if history_len >= 7 else features['lag_1_day']
        )
        features['lag_14_day'] = (
            float(historical_sales.iloc[-14]) if history_len >= 14 else features['lag_7_day']
        )

        # Rolling statistics; each window inherits the previous (smaller)
        # window's values when there is not enough history to fill it.
        fallback = {
            'mean': features['lag_1_day'],
            'std': 0.0,
            'max': features['lag_1_day'],
            'min': features['lag_1_day'],
        }
        for days in (7, 14, 30):
            if history_len >= days:
                window = historical_sales.iloc[-days:]
                stats = {
                    'mean': float(window.mean()),
                    'std': float(window.std()),
                    'max': float(window.max()),
                    'min': float(window.min()),
                }
            else:
                stats = dict(fallback)
            for stat, value in stats.items():
                features[f'rolling_{stat}_{days}d'] = value
            fallback = stats

        # Days between the first observation in the series and the forecast date
        first_sale = historical_sales.index[0]
        try:
            elapsed = forecast_date - first_sale
        except TypeError:
            # Mixing a tz-aware and a tz-naive timestamp raises TypeError.
            # Without this fallback the outer handler would throw away every
            # feature computed above. Anything that is not a datetime on both
            # sides is a genuine error and is re-raised.
            if not (isinstance(first_sale, datetime) and isinstance(forecast_date, datetime)):
                raise
            elapsed = without_timezone(forecast_date) - without_timezone(first_sale)
        features['days_since_start'] = elapsed.days

        # Momentum: latest observation versus the 7th most recent one
        if history_len >= 7:
            features['momentum_1_7'] = features['lag_1_day'] - features['lag_7_day']
            week_change = historical_sales.iloc[-1] - historical_sales.iloc[-7]
            # Velocity: the same change expressed per day over the week
            features['velocity_week'] = float(week_change / 7.0)
        else:
            features['momentum_1_7'] = 0.0
            features['velocity_week'] = 0.0

        # Trend: short-term average versus long-term average
        if history_len >= 30:
            features['trend_7_30'] = features['rolling_mean_7d'] - features['rolling_mean_30d']
        else:
            features['trend_7_30'] = 0.0

        logger.debug("Historical features calculated",
                     lag_1_day=features['lag_1_day'],
                     rolling_mean_7d=features['rolling_mean_7d'],
                     rolling_mean_30d=features['rolling_mean_30d'],
                     momentum=features['momentum_1_7'])

        return features

    except Exception as e:
        logger.error("Error calculating historical features",
                     error=str(e))
        # Best-effort: prediction proceeds with neutral historical features
        return default_features()
def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
"""Convert features to Prophet-compatible DataFrame - COMPLETE FEATURE MATCHING"""
@@ -539,6 +784,9 @@ class PredictionService:
'is_month_start': int(forecast_date.day <= 3),
'is_month_end': int(forecast_date.day >= 28),
'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
# CRITICAL FIX: Add is_payday feature to match training service
# Training defines: is_payday = (day == 15 OR is_month_end)
'is_payday': int((forecast_date.day == 15) or self._is_end_of_month(forecast_date)),
# Weather-based derived features
'temp_squared': temperature ** 2,
@@ -600,23 +848,57 @@ class PredictionService:
# Day features
'is_peak_bakery_day': int(day_of_week in [4, 5, 6]),
'is_high_demand_month': int(forecast_date.month in [6, 7, 8, 12]),
'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9])
'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9]),
# CRITICAL FIX: Cyclical encoding features (MATCH TRAINING)
# These encode day_of_week and month as sin/cos for cyclical patterns
'day_of_week_sin': float(np.sin(2 * np.pi * day_of_week / 7)),
'day_of_week_cos': float(np.cos(2 * np.pi * day_of_week / 7)),
'month_sin': float(np.sin(2 * np.pi * forecast_date.month / 12)),
'month_cos': float(np.cos(2 * np.pi * forecast_date.month / 12)),
# CRITICAL FIX: Historical features (lagged, rolling, trend)
# These will be populated from historical sales data
# Default to 0.0 here, will be updated if historical data is provided
'lag_1_day': float(features.get('lag_1_day', 0.0)),
'lag_7_day': float(features.get('lag_7_day', 0.0)),
'lag_14_day': float(features.get('lag_14_day', 0.0)),
'rolling_mean_7d': float(features.get('rolling_mean_7d', 0.0)),
'rolling_std_7d': float(features.get('rolling_std_7d', 0.0)),
'rolling_max_7d': float(features.get('rolling_max_7d', 0.0)),
'rolling_min_7d': float(features.get('rolling_min_7d', 0.0)),
'rolling_mean_14d': float(features.get('rolling_mean_14d', 0.0)),
'rolling_std_14d': float(features.get('rolling_std_14d', 0.0)),
'rolling_max_14d': float(features.get('rolling_max_14d', 0.0)),
'rolling_min_14d': float(features.get('rolling_min_14d', 0.0)),
'rolling_mean_30d': float(features.get('rolling_mean_30d', 0.0)),
'rolling_std_30d': float(features.get('rolling_std_30d', 0.0)),
'rolling_max_30d': float(features.get('rolling_max_30d', 0.0)),
'rolling_min_30d': float(features.get('rolling_min_30d', 0.0)),
'days_since_start': int(features.get('days_since_start', 0)),
'momentum_1_7': float(features.get('momentum_1_7', 0.0)),
'trend_7_30': float(features.get('trend_7_30', 0.0)),
'velocity_week': float(features.get('velocity_week', 0.0)),
}
# Calculate interaction features
is_holiday = new_features['is_holiday']
is_pleasant = new_features['is_pleasant_day']
is_rainy = new_features['is_rainy_day']
is_payday = new_features['is_payday']
interaction_features = {
# Weekend interactions
'weekend_temp_interaction': is_weekend * temperature,
'weekend_pleasant_weather': is_weekend * is_pleasant,
'weekend_traffic_interaction': is_weekend * traffic,
# Holiday interactions
'holiday_temp_interaction': is_holiday * temperature,
'holiday_traffic_interaction': is_holiday * traffic,
# CRITICAL FIX: Add payday_weekend_interaction to match training service
'payday_weekend_interaction': is_payday * is_weekend,
# Season interactions
'season_temp_interaction': season * temperature,
@@ -625,7 +907,11 @@ class PredictionService:
# Rain-traffic interactions
'rain_traffic_interaction': is_rainy * traffic,
'rain_speed_interaction': is_rainy * avg_speed,
# CRITICAL FIX: Add missing interaction features from training
'rain_weekend_interaction': is_rainy * is_weekend,
'friday_traffic_interaction': int(day_of_week == 4) * traffic,
# Day-weather interactions
'day_temp_interaction': day_of_week * temperature,
'month_temp_interaction': forecast_date.month * temperature,
@@ -707,4 +993,14 @@ class PredictionService:
elif precipitation <= 10:
return 2 # Moderate rain
else:
return 3 # Heavy rain
return 3 # Heavy rain
def _is_end_of_month(self, date: datetime) -> bool:
"""
Check if date is the last day of the month - MATCH TRAINING SERVICE
Training uses: df[date_column].dt.is_month_end
"""
import calendar
# Get the last day of the month
last_day = calendar.monthrange(date.year, date.month)[1]
return date.day == last_day