Fix orchestrator issues
@@ -287,7 +287,7 @@ async def generate_batch_forecast(
     from app.schemas.forecasts import BatchForecastResponse
     now = datetime.now(timezone.utc)
     return BatchForecastResponse(
-        id=batch_result.get('batch_id', str(uuid.uuid4())),
+        id=batch_result.get('id', str(uuid.uuid4())),  # Use 'id' field (UUID) instead of 'batch_id' (string)
         tenant_id=tenant_id,
         batch_name=updated_request.batch_name,
         status="completed",
@@ -5,6 +5,7 @@ Main forecasting service that uses the repository pattern for data access

 import structlog
 import uuid
+import asyncio
 from typing import Dict, List, Any, Optional
 from datetime import datetime, date, timedelta, timezone
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -63,8 +64,10 @@ class EnhancedForecastingService:
         """Generate batch forecasts using repository pattern"""
         try:
             # Implementation would use repository pattern to generate multiple forecasts
+            batch_uuid = uuid.uuid4()
             return {
-                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                "id": str(batch_uuid),  # UUID for database references
+                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",  # Human-readable batch identifier
                 "tenant_id": tenant_id,
                 "forecasts": [],
                 "total_forecasts": 0,
@@ -368,7 +371,7 @@ class EnhancedForecastingService:

             forecast = await repos['forecast'].create_forecast(forecast_data)

-            # Step 7: Cache the prediction
+            # Step 6: Cache the prediction
             await repos['cache'].cache_prediction(
                 tenant_id=tenant_id,
                 inventory_product_id=request.inventory_product_id,
@@ -521,51 +524,62 @@ class EnhancedForecastingService:
         Generate forecast using a pre-fetched weather map to avoid multiple API calls.
         """
         start_time = datetime.now(timezone.utc)

         try:
             logger.info("Generating enhanced forecast with weather map",
                        tenant_id=tenant_id,
                        inventory_product_id=request.inventory_product_id,
                        date=request.forecast_date.isoformat())

-            # Get session and initialize repositories
+            # CRITICAL FIX: Get model BEFORE opening database session to prevent session blocking during HTTP calls
+            # This prevents holding database connections during potentially slow external API calls
+            logger.debug("Fetching model data before opening database session",
+                        tenant_id=tenant_id,
+                        inventory_product_id=request.inventory_product_id)
+
+            model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
+
+            if not model_data:
+                raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
+
+            logger.debug("Model data fetched successfully",
+                        tenant_id=tenant_id,
+                        model_id=model_data.get('model_id'))
+
+            # Prepare features (this doesn't make external HTTP calls when using weather_map)
+            features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
+
+            # Now open database session AFTER external HTTP calls are complete
             async with self.database_manager.get_background_session() as session:
                 repos = await self._init_repositories(session)

                 # Step 1: Check cache first
                 cached_prediction = await repos['cache'].get_cached_prediction(
                     tenant_id, request.inventory_product_id, request.location, request.forecast_date
                 )

                 if cached_prediction:
                     logger.debug("Using cached prediction",
                                tenant_id=tenant_id,
                                inventory_product_id=request.inventory_product_id)
                     return self._create_forecast_response_from_cache(cached_prediction)

-                # Step 2: Get model with validation
-                model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)
-
-                if not model_data:
-                    raise ValueError(f"No valid model available for product: {request.inventory_product_id}")
-
-                # Step 3: Prepare features with fallbacks, using the weather map
-                features = await self._prepare_forecast_features_with_fallbacks_and_weather_map(tenant_id, request, weather_map)
-
-                # Step 4: Generate prediction
+                # Step 2: Model data already fetched above (before session opened)
+
+                # Step 3: Generate prediction
                 prediction_result = await self.prediction_service.predict(
                     model_id=model_data['model_id'],
                     model_path=model_data['model_path'],
                     features=features,
                     confidence_level=request.confidence_level
                 )

-                # Step 5: Apply business rules
+                # Step 4: Apply business rules
                 adjusted_prediction = self._apply_business_rules(
                     prediction_result, request, features
                 )

-                # Step 6: Save forecast using repository
+                # Step 5: Save forecast using repository
                 # Convert forecast_date to datetime if it's a string
                 forecast_datetime = request.forecast_date
                 if isinstance(forecast_datetime, str):
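Note on the reordering above: the point of the fix is that all slow external I/O (model fetch, feature preparation) completes before a pooled database session is acquired, so no connection is ever held across an HTTP call. A minimal runnable sketch of the pattern (the stub names are illustrative, not the service's API):

```python
import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def open_session():
    # Stand-in for a pooled DB session; the pool slot is held only inside this block.
    print("session acquired")
    try:
        yield "session"
    finally:
        print("session released")

async def fetch_model():
    await asyncio.sleep(1.0)  # simulates a slow external HTTP call
    return "model"

async def forecast():
    model = await fetch_model()   # slow I/O finishes BEFORE any session exists
    async with open_session():    # connection held only for the fast DB work
        return model

asyncio.run(forecast())
```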
@@ -599,7 +613,7 @@ class EnhancedForecastingService:

             forecast = await repos['forecast'].create_forecast(forecast_data)

-            # Step 7: Cache the prediction
+            # Step 6: Cache the prediction
             await repos['cache'].cache_prediction(
                 tenant_id=tenant_id,
                 inventory_product_id=request.inventory_product_id,
@@ -813,32 +827,51 @@ class EnhancedForecastingService:

     # Additional helper methods from original service
     async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
-        """Get the latest trained model with fallback strategies"""
+        """
+        Get the latest trained model with fallback strategies.
+
+        CRITICAL FIX: Added timeout protection to prevent hanging during external API calls.
+        This ensures we don't block indefinitely if the training service is unresponsive.
+        """
         try:
-            model_data = await self.model_client.get_best_model_for_forecasting(
-                tenant_id=tenant_id,
-                inventory_product_id=inventory_product_id
+            # Add timeout protection (15 seconds) to prevent hanging
+            # This is shorter than the default 30s to fail fast and avoid blocking
+            model_data = await asyncio.wait_for(
+                self.model_client.get_best_model_for_forecasting(
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id
+                ),
+                timeout=15.0
             )

             if model_data:
                 logger.info("Found specific model for product",
                            inventory_product_id=inventory_product_id,
                            model_id=model_data.get('model_id'))
                 return model_data

-            # Fallback: Try to get any model for this tenant
-            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)
+            # Fallback: Try to get any model for this tenant (also with timeout)
+            fallback_model = await asyncio.wait_for(
+                self.model_client.get_any_model_for_tenant(tenant_id),
+                timeout=15.0
+            )

             if fallback_model:
                 logger.info("Using fallback model",
                            model_id=fallback_model.get('model_id'))
                 return fallback_model

             logger.error("No models available for tenant", tenant_id=tenant_id)
             return None

+        except asyncio.TimeoutError:
+            logger.error("Timeout fetching model data from training service",
+                        tenant_id=tenant_id,
+                        inventory_product_id=inventory_product_id,
+                        timeout_seconds=15)
+            return None
         except Exception as e:
-            logger.error("Error getting model", error=str(e))
+            logger.error("Error getting model", error=str(e), tenant_id=tenant_id)
             return None

     async def _prepare_forecast_features_with_fallbacks(
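The timeout wrapper above relies on `asyncio.wait_for`, which cancels the awaited call and raises `asyncio.TimeoutError` once the deadline passes; a compact demonstration of the failure path the diff now handles:

```python
import asyncio

async def slow_call():
    await asyncio.sleep(30)  # pretend the training service hangs
    return {"model_id": "m-1"}

async def get_model_with_timeout():
    try:
        # wait_for cancels slow_call() and raises TimeoutError after 0.1s
        return await asyncio.wait_for(slow_call(), timeout=0.1)
    except asyncio.TimeoutError:
        return None  # caller falls back to the next strategy, as in the diff

print(asyncio.run(get_model_with_timeout()))  # -> None
```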
@@ -857,6 +890,9 @@ class EnhancedForecastingService:
             "week_of_year": request.forecast_date.isocalendar().week,
             "season": self._get_season(request.forecast_date.month),
             "is_holiday": self._is_spanish_holiday(request.forecast_date),
+            # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
+            "tenant_id": tenant_id,
+            "inventory_product_id": request.inventory_product_id,
         }

         # Fetch REAL weather data from external service
@@ -951,6 +987,9 @@ class EnhancedForecastingService:
             "week_of_year": request.forecast_date.isocalendar().week,
             "season": self._get_season(request.forecast_date.month),
             "is_holiday": self._is_spanish_holiday(request.forecast_date),
+            # CRITICAL FIX: Add tenant_id and inventory_product_id for historical feature enrichment
+            "tenant_id": tenant_id,
+            "inventory_product_id": request.inventory_product_id,
         }

         # Use the pre-fetched weather data from the weather map to avoid additional API calls
@@ -20,6 +20,7 @@ import joblib
 from app.core.config import settings
 from shared.monitoring.metrics import MetricsCollector
 from shared.database.base import create_database_manager
+from shared.clients import get_sales_client

 logger = structlog.get_logger()
 metrics = MetricsCollector("forecasting-service")
@@ -34,6 +35,8 @@ class PredictionService:
         self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "forecasting-service")
         self.model_cache = {}
         self.cache_ttl = 3600  # 1 hour cache
+        # Initialize sales client for fetching historical data
+        self.sales_client = get_sales_client(settings, "forecasting")

     async def validate_prediction_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
         """Validate prediction request"""
@@ -79,7 +82,34 @@ class PredictionService:

         if not model:
             raise ValueError(f"Model {model_id} not found or failed to load")

+        # CRITICAL FIX: Fetch historical sales data and calculate historical features
+        # This populates lag, rolling, and trend features for better predictions
+        # Using 90 days for better trend analysis and more robust rolling statistics
+        if 'tenant_id' in features and 'inventory_product_id' in features and 'date' in features:
+            try:
+                forecast_date = pd.to_datetime(features['date'])
+                historical_sales = await self._fetch_historical_sales(
+                    tenant_id=features['tenant_id'],
+                    inventory_product_id=features['inventory_product_id'],
+                    forecast_date=forecast_date,
+                    days_back=90  # Changed from 30 to 90 for better historical context
+                )
+
+                # Calculate historical features and merge into features dict
+                historical_features = self._calculate_historical_features(
+                    historical_sales, forecast_date
+                )
+                features.update(historical_features)
+
+                logger.info("Historical features enriched",
+                           lag_1_day=historical_features.get('lag_1_day'),
+                           rolling_mean_7d=historical_features.get('rolling_mean_7d'))
+            except Exception as e:
+                logger.warning("Failed to enrich with historical features, using defaults",
+                              error=str(e))
+                # Features dict will use defaults (0.0) from _prepare_prophet_features
+
         # Prepare features for Prophet model
         prophet_df = self._prepare_prophet_features(features)
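The enrichment must run before `_prepare_prophet_features` because the feature builder reads every historical key with `features.get(key, 0.0)`; `dict.update` overwrites matching keys, and anything the enrichment could not supply silently keeps the 0.0 default. In miniature:

```python
# Values merged in by enrichment win; missing keys fall back to defaults.
features = {"date": "2024-06-01", "tenant_id": "t1"}
historical = {"lag_1_day": 42.0, "rolling_mean_7d": 38.5}
features.update(historical)

print(float(features.get("lag_1_day", 0.0)))   # 42.0 -- enriched value
print(float(features.get("lag_14_day", 0.0)))  # 0.0  -- default, never enriched
```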
@@ -444,7 +474,222 @@ class PredictionService:
         except Exception as e:
             logger.error(f"Model validation error: {e}")
             return False

+    async def _fetch_historical_sales(
+        self,
+        tenant_id: str,
+        inventory_product_id: str,
+        forecast_date: datetime,
+        days_back: int = 90
+    ) -> pd.Series:
+        """
+        Fetch historical sales data for calculating lagged and rolling features.
+
+        Args:
+            tenant_id: Tenant UUID
+            inventory_product_id: Product UUID
+            forecast_date: The date we're forecasting for
+            days_back: Number of days of history to fetch (default 90 for better trend analysis)
+
+        Returns:
+            pandas Series with sales quantities indexed by date
+        """
+        try:
+            # Calculate date range
+            end_date = forecast_date - pd.Timedelta(days=1)  # Day before forecast
+            start_date = end_date - pd.Timedelta(days=days_back)
+
+            logger.debug("Fetching historical sales for feature calculation",
+                        tenant_id=tenant_id,
+                        product_id=inventory_product_id,
+                        start_date=start_date.date(),
+                        end_date=end_date.date(),
+                        days_back=days_back)
+
+            # Fetch sales data from sales service
+            sales_data = await self.sales_client.get_sales_data(
+                tenant_id=tenant_id,
+                start_date=start_date.strftime("%Y-%m-%d"),
+                end_date=end_date.strftime("%Y-%m-%d"),
+                product_id=inventory_product_id,
+                aggregation="daily"
+            )
+
+            if not sales_data:
+                logger.warning("No historical sales data found",
+                              tenant_id=tenant_id,
+                              product_id=inventory_product_id)
+                return pd.Series(dtype=float)
+
+            # Convert to pandas Series indexed by date
+            df = pd.DataFrame(sales_data)
+            df['sale_date'] = pd.to_datetime(df['sale_date'])
+            df = df.set_index('sale_date')
+
+            # Extract quantity column (could be 'quantity' or 'total_quantity')
+            if 'quantity' in df.columns:
+                series = df['quantity']
+            elif 'total_quantity' in df.columns:
+                series = df['total_quantity']
+            else:
+                logger.warning("Sales data missing quantity field",
+                              columns=list(df.columns))
+                return pd.Series(dtype=float)
+
+            logger.debug("Historical sales fetched successfully",
+                        records=len(series),
+                        date_range=f"{series.index.min()} to {series.index.max()}")
+
+            return series.sort_index()
+
+        except Exception as e:
+            logger.error("Error fetching historical sales",
+                        error=str(e),
+                        tenant_id=tenant_id,
+                        product_id=inventory_product_id)
+            return pd.Series(dtype=float)
+
+    def _calculate_historical_features(
+        self,
+        historical_sales: pd.Series,
+        forecast_date: datetime
+    ) -> Dict[str, float]:
+        """
+        Calculate lagged, rolling, and trend features from historical sales data.
+
+        Args:
+            historical_sales: Series of sales quantities indexed by date
+            forecast_date: The date we're forecasting for
+
+        Returns:
+            Dictionary of calculated features
+        """
+        features = {}
+
+        try:
+            if len(historical_sales) == 0:
+                logger.warning("No historical data available, using default values")
+                # Return all features with default values (0.0)
+                return {
+                    # Lagged features
+                    'lag_1_day': 0.0,
+                    'lag_7_day': 0.0,
+                    'lag_14_day': 0.0,
+                    # Rolling statistics (7-day window)
+                    'rolling_mean_7d': 0.0,
+                    'rolling_std_7d': 0.0,
+                    'rolling_max_7d': 0.0,
+                    'rolling_min_7d': 0.0,
+                    # Rolling statistics (14-day window)
+                    'rolling_mean_14d': 0.0,
+                    'rolling_std_14d': 0.0,
+                    'rolling_max_14d': 0.0,
+                    'rolling_min_14d': 0.0,
+                    # Rolling statistics (30-day window)
+                    'rolling_mean_30d': 0.0,
+                    'rolling_std_30d': 0.0,
+                    'rolling_max_30d': 0.0,
+                    'rolling_min_30d': 0.0,
+                    # Trend features
+                    'days_since_start': 0,
+                    'momentum_1_7': 0.0,
+                    'trend_7_30': 0.0,
+                    'velocity_week': 0.0,
+                }
+
+            # Calculate lagged features
+            features['lag_1_day'] = float(historical_sales.iloc[-1]) if len(historical_sales) >= 1 else 0.0
+            features['lag_7_day'] = float(historical_sales.iloc[-7]) if len(historical_sales) >= 7 else features['lag_1_day']
+            features['lag_14_day'] = float(historical_sales.iloc[-14]) if len(historical_sales) >= 14 else features['lag_7_day']
+
+            # Calculate rolling statistics (7-day window)
+            if len(historical_sales) >= 7:
+                window_7d = historical_sales.iloc[-7:]
+                features['rolling_mean_7d'] = float(window_7d.mean())
+                features['rolling_std_7d'] = float(window_7d.std())
+                features['rolling_max_7d'] = float(window_7d.max())
+                features['rolling_min_7d'] = float(window_7d.min())
+            else:
+                features['rolling_mean_7d'] = features['lag_1_day']
+                features['rolling_std_7d'] = 0.0
+                features['rolling_max_7d'] = features['lag_1_day']
+                features['rolling_min_7d'] = features['lag_1_day']
+
+            # Calculate rolling statistics (14-day window)
+            if len(historical_sales) >= 14:
+                window_14d = historical_sales.iloc[-14:]
+                features['rolling_mean_14d'] = float(window_14d.mean())
+                features['rolling_std_14d'] = float(window_14d.std())
+                features['rolling_max_14d'] = float(window_14d.max())
+                features['rolling_min_14d'] = float(window_14d.min())
+            else:
+                features['rolling_mean_14d'] = features['rolling_mean_7d']
+                features['rolling_std_14d'] = features['rolling_std_7d']
+                features['rolling_max_14d'] = features['rolling_max_7d']
+                features['rolling_min_14d'] = features['rolling_min_7d']
+
+            # Calculate rolling statistics (30-day window)
+            if len(historical_sales) >= 30:
+                window_30d = historical_sales.iloc[-30:]
+                features['rolling_mean_30d'] = float(window_30d.mean())
+                features['rolling_std_30d'] = float(window_30d.std())
+                features['rolling_max_30d'] = float(window_30d.max())
+                features['rolling_min_30d'] = float(window_30d.min())
+            else:
+                features['rolling_mean_30d'] = features['rolling_mean_14d']
+                features['rolling_std_30d'] = features['rolling_std_14d']
+                features['rolling_max_30d'] = features['rolling_max_14d']
+                features['rolling_min_30d'] = features['rolling_min_14d']
+
+            # Calculate trend features
+            if len(historical_sales) > 0:
+                # Days since first sale
+                features['days_since_start'] = (forecast_date - historical_sales.index[0]).days
+
+                # Momentum (difference between recent lag_1_day and lag_7_day)
+                if len(historical_sales) >= 7:
+                    features['momentum_1_7'] = features['lag_1_day'] - features['lag_7_day']
+                else:
+                    features['momentum_1_7'] = 0.0
+
+                # Trend (difference between recent 7-day and 30-day averages)
+                if len(historical_sales) >= 30:
+                    features['trend_7_30'] = features['rolling_mean_7d'] - features['rolling_mean_30d']
+                else:
+                    features['trend_7_30'] = 0.0
+
+                # Velocity (rate of change over the last week)
+                if len(historical_sales) >= 7:
+                    week_change = historical_sales.iloc[-1] - historical_sales.iloc[-7]
+                    features['velocity_week'] = float(week_change / 7.0)
+                else:
+                    features['velocity_week'] = 0.0
+            else:
+                features['days_since_start'] = 0
+                features['momentum_1_7'] = 0.0
+                features['trend_7_30'] = 0.0
+                features['velocity_week'] = 0.0
+
+            logger.debug("Historical features calculated",
+                        lag_1_day=features['lag_1_day'],
+                        rolling_mean_7d=features['rolling_mean_7d'],
+                        rolling_mean_30d=features['rolling_mean_30d'],
+                        momentum=features['momentum_1_7'])
+
+            return features
+
+        except Exception as e:
+            logger.error("Error calculating historical features",
+                        error=str(e))
+            # Return default values on error
+            return {k: 0.0 for k in [
+                'lag_1_day', 'lag_7_day', 'lag_14_day',
+                'rolling_mean_7d', 'rolling_std_7d', 'rolling_max_7d', 'rolling_min_7d',
+                'rolling_mean_14d', 'rolling_std_14d', 'rolling_max_14d', 'rolling_min_14d',
+                'rolling_mean_30d', 'rolling_std_30d', 'rolling_max_30d', 'rolling_min_30d',
+                'momentum_1_7', 'trend_7_30', 'velocity_week'
+            ]} | {'days_since_start': 0}

     def _prepare_prophet_features(self, features: Dict[str, Any]) -> pd.DataFrame:
         """Convert features to Prophet-compatible DataFrame - COMPLETE FEATURE MATCHING"""
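One caveat about the lag calculations above: `iloc[-7]` is positional, so it returns the value seven *records* back, which equals seven *days* back only if the series has one row per calendar day with no gaps. If the sales service can return sparse days, a defensive normalization (a suggestion, not part of this diff) is to reindex to daily frequency first:

```python
import pandas as pd

idx = pd.to_datetime(["2024-06-01", "2024-06-03", "2024-06-05",
                      "2024-06-07", "2024-06-09", "2024-06-11", "2024-06-13"])
sales = pd.Series([5.0, 7.0, 6.0, 8.0, 9.0, 4.0, 10.0], index=idx)

print(sales.iloc[-7])                    # 5.0 -> seven RECORDS back = 12 days ago

daily = sales.asfreq("D", fill_value=0)  # contiguous days; missing days = 0 sales
print(daily.iloc[-7])                    # 8.0 -> exactly seven DAYS before the end
```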
@@ -539,6 +784,9 @@ class PredictionService:
             'is_month_start': int(forecast_date.day <= 3),
             'is_month_end': int(forecast_date.day >= 28),
             'is_payday_period': int((forecast_date.day <= 5) or (forecast_date.day >= 25)),
+            # CRITICAL FIX: Add is_payday feature to match training service
+            # Training defines: is_payday = (day == 15 OR is_month_end)
+            'is_payday': int((forecast_date.day == 15) or self._is_end_of_month(forecast_date)),

             # Weather-based derived features
             'temp_squared': temperature ** 2,
@@ -600,23 +848,57 @@ class PredictionService:
             # Day features
             'is_peak_bakery_day': int(day_of_week in [4, 5, 6]),
             'is_high_demand_month': int(forecast_date.month in [6, 7, 8, 12]),
-            'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9])
+            'is_warm_season': int(forecast_date.month in [4, 5, 6, 7, 8, 9]),
+
+            # CRITICAL FIX: Cyclical encoding features (MATCH TRAINING)
+            # These encode day_of_week and month as sin/cos for cyclical patterns
+            'day_of_week_sin': float(np.sin(2 * np.pi * day_of_week / 7)),
+            'day_of_week_cos': float(np.cos(2 * np.pi * day_of_week / 7)),
+            'month_sin': float(np.sin(2 * np.pi * forecast_date.month / 12)),
+            'month_cos': float(np.cos(2 * np.pi * forecast_date.month / 12)),
+
+            # CRITICAL FIX: Historical features (lagged, rolling, trend)
+            # These will be populated from historical sales data
+            # Default to 0.0 here, will be updated if historical data is provided
+            'lag_1_day': float(features.get('lag_1_day', 0.0)),
+            'lag_7_day': float(features.get('lag_7_day', 0.0)),
+            'lag_14_day': float(features.get('lag_14_day', 0.0)),
+            'rolling_mean_7d': float(features.get('rolling_mean_7d', 0.0)),
+            'rolling_std_7d': float(features.get('rolling_std_7d', 0.0)),
+            'rolling_max_7d': float(features.get('rolling_max_7d', 0.0)),
+            'rolling_min_7d': float(features.get('rolling_min_7d', 0.0)),
+            'rolling_mean_14d': float(features.get('rolling_mean_14d', 0.0)),
+            'rolling_std_14d': float(features.get('rolling_std_14d', 0.0)),
+            'rolling_max_14d': float(features.get('rolling_max_14d', 0.0)),
+            'rolling_min_14d': float(features.get('rolling_min_14d', 0.0)),
+            'rolling_mean_30d': float(features.get('rolling_mean_30d', 0.0)),
+            'rolling_std_30d': float(features.get('rolling_std_30d', 0.0)),
+            'rolling_max_30d': float(features.get('rolling_max_30d', 0.0)),
+            'rolling_min_30d': float(features.get('rolling_min_30d', 0.0)),
+            'days_since_start': int(features.get('days_since_start', 0)),
+            'momentum_1_7': float(features.get('momentum_1_7', 0.0)),
+            'trend_7_30': float(features.get('trend_7_30', 0.0)),
+            'velocity_week': float(features.get('velocity_week', 0.0)),
         }

         # Calculate interaction features
         is_holiday = new_features['is_holiday']
         is_pleasant = new_features['is_pleasant_day']
         is_rainy = new_features['is_rainy_day']
+        is_payday = new_features['is_payday']

         interaction_features = {
             # Weekend interactions
             'weekend_temp_interaction': is_weekend * temperature,
             'weekend_pleasant_weather': is_weekend * is_pleasant,
             'weekend_traffic_interaction': is_weekend * traffic,

             # Holiday interactions
             'holiday_temp_interaction': is_holiday * temperature,
             'holiday_traffic_interaction': is_holiday * traffic,

+            # CRITICAL FIX: Add payday_weekend_interaction to match training service
+            'payday_weekend_interaction': is_payday * is_weekend,
+
             # Season interactions
             'season_temp_interaction': season * temperature,
@@ -625,7 +907,11 @@ class PredictionService:
             # Rain-traffic interactions
             'rain_traffic_interaction': is_rainy * traffic,
             'rain_speed_interaction': is_rainy * avg_speed,

+            # CRITICAL FIX: Add missing interaction features from training
+            'rain_weekend_interaction': is_rainy * is_weekend,
+            'friday_traffic_interaction': int(day_of_week == 4) * traffic,
+
             # Day-weather interactions
             'day_temp_interaction': day_of_week * temperature,
             'month_temp_interaction': forecast_date.month * temperature,
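The sin/cos pairs added above exist because a raw `day_of_week` treats Sunday (6) and Monday (0) as maximally distant, while the cyclical encoding places all seven days evenly on the unit circle so adjacent days stay adjacent across the week boundary:

```python
import numpy as np

def cyclical(value: int, period: int) -> tuple:
    angle = 2 * np.pi * value / period
    return float(np.sin(angle)), float(np.cos(angle))

# Saturday/Sunday/Monday/Tuesday: the 6 -> 0 boundary wraps smoothly on the circle
for dow in (5, 6, 0, 1):
    print(dow, [round(v, 2) for v in cyclical(dow, 7)])
```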
@@ -707,4 +993,14 @@ class PredictionService:
         elif precipitation <= 10:
             return 2  # Moderate rain
         else:
             return 3  # Heavy rain
+
+    def _is_end_of_month(self, date: datetime) -> bool:
+        """
+        Check if date is the last day of the month - MATCH TRAINING SERVICE
+        Training uses: df[date_column].dt.is_month_end
+        """
+        import calendar
+        # Get the last day of the month
+        last_day = calendar.monthrange(date.year, date.month)[1]
+        return date.day == last_day
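The `calendar.monthrange` approach matches the pandas `is_month_end` semantics the training service uses, including leap years; a quick equivalence check:

```python
import calendar
import pandas as pd

def is_end_of_month(d) -> bool:
    return d.day == calendar.monthrange(d.year, d.month)[1]

# Agreement with pandas' is_month_end, including leap February
for s in ("2024-02-29", "2023-02-28", "2024-04-30", "2024-04-29"):
    ts = pd.Timestamp(s)
    assert is_end_of_month(ts) == bool(ts.is_month_end)
print("calendar.monthrange matches pandas is_month_end")
```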
@@ -34,7 +34,8 @@ route_builder = RouteBuilder('notifications')

 # Dependency injection for enhanced notification service
 def get_enhanced_notification_service():
-    database_manager = create_database_manager()
+    from app.core.config import settings
+    database_manager = create_database_manager(settings.DATABASE_URL, "notification")
     return EnhancedNotificationService(database_manager)

@@ -47,7 +48,6 @@ def get_enhanced_notification_service():
     response_model=NotificationResponse,
     status_code=201
 )
-@require_user_role(["member", "admin", "owner"])
 @track_endpoint_metrics("notification_send")
 async def send_notification(
     notification_data: Dict[str, Any],
@@ -55,11 +55,23 @@ async def send_notification(
     current_user: Dict[str, Any] = Depends(get_current_user_dep),
     notification_service: EnhancedNotificationService = Depends(get_enhanced_notification_service)
 ):
-    """Send a single notification with enhanced validation and features"""
+    """Send a single notification with enhanced validation and features - allows service-to-service calls"""

     try:
+        # Allow service-to-service calls (skip role check for service tokens)
+        is_service_call = current_user.get("type") == "service"
+
+        if not is_service_call:
+            # Check user role for non-service calls
+            user_role = current_user.get("role", "").lower()
+            if user_role not in ["member", "admin", "owner"]:
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail="Insufficient permissions"
+                )
+
         # Check permissions for broadcast notifications (Admin+ only)
-        if notification_data.get("broadcast", False):
+        if notification_data.get("broadcast", False) and not is_service_call:
             user_role = current_user.get("role", "").lower()
             if user_role not in ["admin", "owner"]:
                 raise HTTPException(
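With the `@require_user_role` decorator removed (see the earlier hunk), authorization moves inline so it can be bypassed for service tokens. The resulting decision table, condensed into a standalone helper (illustrative only, not the service's actual API):

```python
def authorize(current_user: dict, broadcast: bool) -> None:
    """Raise unless the caller may send this notification."""
    if current_user.get("type") == "service":
        return  # service-to-service calls skip role checks entirely
    role = current_user.get("role", "").lower()
    if role not in ("member", "admin", "owner"):
        raise PermissionError("Insufficient permissions")
    if broadcast and role not in ("admin", "owner"):
        raise PermissionError("Broadcast requires admin or owner")

authorize({"type": "service"}, broadcast=True)   # ok: service token
authorize({"role": "admin"}, broadcast=True)     # ok: admin user
# authorize({"role": "member"}, broadcast=True)  # raises PermissionError
```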
@@ -111,10 +123,14 @@ async def send_notification(
                 detail=f"Invalid priority: {notification_data['priority']}"
             )

+        # Use tenant_id from path parameter (especially for service calls)
+        effective_tenant_id = str(tenant_id) if is_service_call else current_user.get("tenant_id")
+        effective_sender_id = current_user.get("user_id", "system")
+
         # Create notification using enhanced service
         notification = await notification_service.create_notification(
-            tenant_id=current_user.get("tenant_id"),
-            sender_id=current_user["user_id"],
+            tenant_id=effective_tenant_id,
+            sender_id=effective_sender_id,
             notification_type=notification_type,
             message=notification_data["message"],
             recipient_id=notification_data.get("recipient_id"),
@@ -131,18 +147,20 @@ async def send_notification(

         logger.info("Notification sent successfully",
                    notification_id=notification.id,
-                   tenant_id=current_user.get("tenant_id"),
+                   tenant_id=effective_tenant_id,
                    type=notification_type.value,
-                   priority=priority.value)
+                   priority=priority.value,
+                   is_service_call=is_service_call)

         return NotificationResponse.from_orm(notification)

     except HTTPException:
         raise
     except Exception as e:
+        effective_tenant_id = str(tenant_id) if current_user.get("type") == "service" else current_user.get("tenant_id")
         logger.error("Failed to send notification",
-                    tenant_id=current_user.get("tenant_id"),
-                    sender_id=current_user["user_id"],
+                    tenant_id=effective_tenant_id,
+                    sender_id=current_user.get("user_id", "system"),
                     error=str(e))
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -25,7 +25,8 @@ route_builder = RouteBuilder('notifications')

 # Dependency injection for enhanced notification service
 def get_enhanced_notification_service():
-    database_manager = create_database_manager()
+    from app.core.config import settings
+    database_manager = create_database_manager(settings.DATABASE_URL, "notification")
     return EnhancedNotificationService(database_manager)

 # ============================================================================
@@ -46,7 +46,6 @@ class EnhancedNotificationService:
             'log': self.log_repo
         }

-    @transactional
     async def create_notification(
         self,
         tenant_id: str,
@@ -70,11 +69,11 @@ class EnhancedNotificationService:
         try:
             async with self.database_manager.get_session() as db_session:
                 async with UnitOfWork(db_session) as uow:
-                    # Register repositories
-                    notification_repo = uow.register_repository("notifications", NotificationRepository)
-                    template_repo = uow.register_repository("templates", TemplateRepository)
-                    preference_repo = uow.register_repository("preferences", PreferenceRepository)
-                    log_repo = uow.register_repository("logs", LogRepository)
+                    # Register repositories with model classes
+                    notification_repo = uow.register_repository("notifications", NotificationRepository, Notification)
+                    template_repo = uow.register_repository("templates", TemplateRepository, NotificationTemplate)
+                    preference_repo = uow.register_repository("preferences", PreferenceRepository, NotificationPreference)
+                    log_repo = uow.register_repository("logs", LogRepository, NotificationLog)

                     notification_data = {
                         "tenant_id": tenant_id,
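The fix passes each repository's SQLAlchemy model class as a third argument. The `UnitOfWork` implementation is not part of this diff, but a minimal sketch consistent with the new call shape (an assumption, offered only to show why the model class is needed) would be:

```python
class UnitOfWork:
    """Registers repositories bound to one session; rolls back on error."""

    def __init__(self, session):
        self.session = session
        self.repositories = {}

    def register_repository(self, name, repo_cls, model_cls):
        # The repository needs its model class to build queries; the old
        # two-argument call left that binding undefined.
        repo = repo_cls(self.session, model_cls)
        self.repositories[name] = repo
        return repo

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if exc_type is not None:
            await self.session.rollback()
```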
@@ -105,145 +105,7 @@ class EnhancedBakeryMLTrainer:
+            return await self._execute_training_pipeline(
+                tenant_id, training_dataset, job_id, session
+            )
-            estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
-
-            # Note: Initial event was already published by API endpoint with estimated product count,
-            # this updates with real count and recalculated time estimates based on actual data
-            await publish_training_started(
-                job_id=job_id,
-                tenant_id=tenant_id,
-                total_products=len(products),
-                estimated_duration_minutes=estimated_duration_minutes,
-                estimated_completion_time=estimated_completion_time.isoformat()
-            )
-
-            # Create initial training log entry
-            await repos['training_log'].update_log_progress(
-                job_id, 5, "data_processing", "running"
-            )
-
-            # ✅ FIX: Flush the session to ensure the update is committed before proceeding
-            # This prevents deadlocks when training methods need to acquire locks
-            await db_session.flush()
-            logger.debug("Flushed session after initial progress update")
-
-            # Process data for each product using enhanced processor
-            logger.info("Processing data using enhanced processor")
-            processed_data = await self._process_all_products_enhanced(
-                sales_df, weather_df, traffic_df, products, tenant_id, job_id, session
-            )
-
-            # Categorize all products for category-specific forecasting
-            logger.info("Categorizing products for optimized forecasting")
-            product_categories = await self._categorize_all_products(
-                sales_df, processed_data
-            )
-            logger.info("Product categorization complete",
-                       total_products=len(product_categories),
-                       categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
-                                             for cat in set(product_categories.values())})
-
-            # Event 2: Data Analysis (20%)
-            # Recalculate time remaining based on elapsed time
-            start_time = await repos['training_log'].get_start_time(job_id)
-            elapsed_seconds = 0
-            if start_time:
-                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
-
-            # Estimate remaining time: we've done ~20% of work (data analysis)
-            # Remaining 80% includes training all products
-            products_to_train = len(processed_data)
-            estimated_remaining_seconds = int(products_to_train * avg_time_per_product)
-
-            # Recalculate estimated completion time
-            estimated_completion_time_data_analysis = calculate_estimated_completion_time(
-                estimated_remaining_seconds / 60
-            )
-
-            await publish_data_analysis(
-                job_id,
-                tenant_id,
-                f"Data analysis completed for {len(processed_data)} products",
-                estimated_time_remaining_seconds=estimated_remaining_seconds,
-                estimated_completion_time=estimated_completion_time_data_analysis.isoformat()
-            )
-
-            # Train models for each processed product with progress aggregation
-            logger.info("Training models with repository integration and progress aggregation")
-
-            # Create progress tracker for parallel product training (20-80%)
-            progress_tracker = ParallelProductProgressTracker(
-                job_id=job_id,
-                tenant_id=tenant_id,
-                total_products=len(processed_data)
-            )
-
-            # Train all models in parallel (without DB writes to avoid session conflicts)
-            # ✅ FIX: Pass db_session to prevent nested session issues and deadlocks
-            training_results = await self._train_all_models_enhanced(
-                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories, db_session
-            )
-
-            # Write all training results to database sequentially (after parallel training completes)
-            logger.info("Writing training results to database sequentially")
-            training_results = await self._write_training_results_to_database(
-                tenant_id, job_id, training_results, repos
-            )
-
-            # Calculate overall training summary with enhanced metrics
-            summary = await self._calculate_enhanced_training_summary(
-                training_results, repos, tenant_id
-            )
-
-            # Calculate successful and failed trainings
-            successful_trainings = len([r for r in training_results.values() if r.get('status') == 'success'])
-            failed_trainings = len([r for r in training_results.values() if r.get('status') == 'error'])
-            total_duration = sum([r.get('training_time_seconds', 0) for r in training_results.values()])
-
-            # Event 4: Training Completed (100%)
-            await publish_training_completed(
-                job_id,
-                tenant_id,
-                successful_trainings,
-                failed_trainings,
-                total_duration
-            )
-
-            # Create comprehensive result with repository data
-            result = {
-                "job_id": job_id,
-                "tenant_id": tenant_id,
-                "status": "completed",
-                "products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
-                "products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
-                "products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
-                "total_products": len(products),
-                "training_results": training_results,
-                "enhanced_summary": summary,
-                "models_trained": summary.get('models_created', {}),
-                "data_info": {
-                    "date_range": {
-                        "start": training_dataset.date_range.start.isoformat(),
-                        "end": training_dataset.date_range.end.isoformat(),
-                        "duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
-                    },
-                    "data_sources": [source.value for source in training_dataset.date_range.available_sources],
-                    "constraints_applied": training_dataset.date_range.constraints
-                },
-                "repository_metadata": {
-                    "total_records_created": summary.get('total_db_records', 0),
-                    "performance_metrics_stored": summary.get('performance_metrics_created', 0),
-                    "artifacts_created": summary.get('artifacts_created', 0)
-                },
-                "completed_at": datetime.now().isoformat()
-            }
-
-            logger.info("Enhanced ML training pipeline completed successfully",
-                       job_id=job_id,
-                       models_created=len([r for r in training_results.values() if r.get('status') == 'success']))
-
-            return result
-
         except Exception as e:
             logger.error("Enhanced ML training pipeline failed",
                         job_id=job_id,
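The removed inline pipeline recalculated its completion estimate at the data-analysis stage as `products_to_train * avg_time_per_product`. That arithmetic, reduced to a standalone sketch (the helper name is mine, not from the codebase):

```python
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple

def estimate_completion(products_to_train: int,
                        avg_time_per_product: float,
                        now: Optional[datetime] = None) -> Tuple[int, datetime]:
    """Return (remaining_seconds, estimated completion timestamp in UTC)."""
    now = now or datetime.now(timezone.utc)
    remaining = int(products_to_train * avg_time_per_product)
    return remaining, now + timedelta(seconds=remaining)

remaining, eta = estimate_completion(40, 12.5)
print(remaining, eta.isoformat())  # 500 seconds from now
```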
@@ -408,6 +270,12 @@ class EnhancedBakeryMLTrainer:
             tenant_id, job_id, training_results, repos
         )

+        # ✅ CRITICAL FIX: Commit the session to persist model records to database
+        # Without this commit, all model records created above are lost when session closes
+        await session.commit()
+        logger.info("Committed model records to database",
+                   models_created=len([r for r in training_results.values() if 'model_record_id' in r]))
+
         # Calculate overall training summary with enhanced metrics
         summary = await self._calculate_enhanced_training_summary(
             training_results, repos, tenant_id
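Why the explicit `commit()` matters: a flush makes pending rows visible only inside the flushing session; when the session closes without committing, the transaction is rolled back and the rows are gone. The same failure mode in miniature with plain sqlite3 from the standard library:

```python
import os
import sqlite3
import tempfile

path = os.path.join(tempfile.mkdtemp(), "demo.db")
writer = sqlite3.connect(path)
writer.execute("CREATE TABLE models (name TEXT)")
writer.commit()

writer.execute("INSERT INTO models VALUES ('prophet-v1')")  # pending, like flush()
reader = sqlite3.connect(path)
print(reader.execute("SELECT COUNT(*) FROM models").fetchone())  # (0,) -- invisible

writer.commit()  # the fix: make the records durable
print(reader.execute("SELECT COUNT(*) FROM models").fetchone())  # (1,)
```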