# services/forecasting/app/services/forecasting_service.py - FIXED INITIALIZATION
"""
Enhanced forecasting service with proper ModelClient initialization.

FIXED: Correct initialization order and dependency injection.
"""
|
|
|
|
import structlog
|
|
from typing import Dict, List, Any, Optional
|
|
from datetime import datetime, date, timedelta
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, desc
|
|
|
|
from app.models.forecasts import Forecast
|
|
from app.schemas.forecasts import ForecastRequest, ForecastResponse
|
|
from app.services.prediction_service import PredictionService
|
|
from app.core.config import settings
|
|
|
|
from app.services.model_client import ModelClient
|
|
from app.services.data_client import DataClient
|
|
|
|
# Module-level structlog logger used by every method of ForecastingService.
logger = structlog.get_logger()
|
|
|
|
class ForecastingService:
    """Generate, persist and query demand forecasts for a tenant.

    The service combines three collaborators created in ``__init__``:

    * ``PredictionService`` - runs the model and returns the raw prediction.
    * ``ModelClient`` - looks up the trained model to use.
    * ``DataClient`` - fetches weather data used as model features.

    Every external lookup has a fallback so that a forecast can still be
    produced when a dependency is unavailable.
    """

    # Coordinates sent with every weather request (Madrid).
    _DEFAULT_LATITUDE = 40.4168
    _DEFAULT_LONGITUDE = -3.7038

    # Pressure value used whenever a weather source does not provide one.
    _DEFAULT_PRESSURE = 1013.0

    # (month, day) pairs of the fixed-date Spanish holidays considered.
    _SPANISH_HOLIDAYS = frozenset({
        (1, 1),    # New Year
        (1, 6),    # Epiphany (Reyes)
        (5, 1),    # Labour Day
        (8, 15),   # Assumption
        (10, 12),  # National Day
        (11, 1),   # All Saints
        (12, 6),   # Constitution Day
        (12, 8),   # Immaculate Conception
        (12, 25),  # Christmas
    })

    # Per-month weather averages for Madrid, used when no live data exists.
    _SEASONAL_WEATHER = {
        # Winter (Dec, Jan, Feb)
        12: {"temperature": 9.0, "precipitation": 2.0, "humidity": 70.0, "wind_speed": 8.0},
        1: {"temperature": 8.0, "precipitation": 2.5, "humidity": 72.0, "wind_speed": 7.0},
        2: {"temperature": 11.0, "precipitation": 2.0, "humidity": 68.0, "wind_speed": 8.0},
        # Spring (Mar, Apr, May)
        3: {"temperature": 15.0, "precipitation": 1.5, "humidity": 65.0, "wind_speed": 9.0},
        4: {"temperature": 18.0, "precipitation": 2.0, "humidity": 62.0, "wind_speed": 8.0},
        5: {"temperature": 23.0, "precipitation": 1.8, "humidity": 58.0, "wind_speed": 7.0},
        # Summer (Jun, Jul, Aug)
        6: {"temperature": 29.0, "precipitation": 0.5, "humidity": 50.0, "wind_speed": 6.0},
        7: {"temperature": 33.0, "precipitation": 0.2, "humidity": 45.0, "wind_speed": 5.0},
        8: {"temperature": 32.0, "precipitation": 0.3, "humidity": 47.0, "wind_speed": 5.0},
        # Autumn (Sep, Oct, Nov)
        9: {"temperature": 26.0, "precipitation": 1.0, "humidity": 55.0, "wind_speed": 6.0},
        10: {"temperature": 19.0, "precipitation": 2.5, "humidity": 65.0, "wind_speed": 7.0},
        11: {"temperature": 13.0, "precipitation": 2.8, "humidity": 70.0, "wind_speed": 8.0},
    }

    def __init__(self):
        """Create the prediction, model and data collaborators."""
        self.prediction_service = PredictionService()
        self.model_client = ModelClient()
        self.data_client = DataClient()

    async def generate_forecast(
        self,
        tenant_id: str,
        request: ForecastRequest,
        db: AsyncSession
    ) -> ForecastResponse:
        """Generate forecast with comprehensive error handling and fallbacks.

        Steps: resolve a model, build the feature dict, run the prediction,
        apply the business-rule adjustments, persist the result and return
        it as a ``ForecastResponse``.

        Args:
            tenant_id: Tenant the forecast belongs to.
            request: Forecast parameters (date, product, location,
                confidence level).
            db: Session used to persist the forecast.

        Returns:
            The persisted forecast converted to a ``ForecastResponse``.

        Raises:
            ValueError: If no model could be resolved for the product.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        # Captured first so processing_time_ms covers the whole pipeline.
        start_time = datetime.now()

        try:
            logger.info("Generating forecast",
                        date=request.forecast_date,
                        product=request.product_name,
                        tenant_id=tenant_id)

            # Step 1: Get model with validation
            model_data = await self._get_latest_model_with_fallback(
                tenant_id, request.product_name
            )

            if not model_data:
                raise ValueError(f"No valid model available for product: {request.product_name}")

            # A missing or zero 'mape' is only reported; the model is still
            # used because there is no alternative to switch to.
            # NOTE(review): confirm whether mape == 0.0 means "metric
            # missing" rather than "perfect fit" for this model registry.
            model_accuracy = model_data.get('mape', 0.0)
            if model_accuracy == 0.0:
                logger.warning("Model accuracy too low: 0.0", tenant_id=tenant_id)
                logger.info("Returning model despite low accuracy - no alternative available",
                            tenant_id=tenant_id)

            # Step 2: Prepare features with fallbacks
            features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

            # Step 3: Generate prediction with the model
            prediction_result = await self.prediction_service.predict(
                model_id=model_data['model_id'],
                model_path=model_data['model_path'],
                features=features,
                confidence_level=request.confidence_level
            )

            # Step 4: Apply business rules and validation
            adjusted_prediction = self._apply_business_rules(
                prediction_result,
                request,
                features
            )

            # Step 5: Save forecast to database. start_time is passed
            # explicitly because _save_forecast cannot see this local.
            forecast = await self._save_forecast(
                db=db,
                tenant_id=tenant_id,
                request=request,
                prediction=adjusted_prediction,
                model_data=model_data,
                features=features,
                start_time=start_time
            )

            logger.info("Forecast generated successfully",
                        forecast_id=forecast.id,
                        prediction=adjusted_prediction['prediction'])

            return ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,

                # Predictions
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,

                # Model info
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,

                # Context
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,

                # External factors
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,

                # Metadata
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            )

        except Exception as e:
            logger.error("Error generating forecast",
                         error=str(e),
                         product=request.product_name,
                         tenant_id=tenant_id)
            raise

    async def _get_latest_model_with_fallback(
        self,
        tenant_id: str,
        product_name: str
    ) -> Optional[Dict[str, Any]]:
        """Get the latest trained model with fallback strategies.

        Tries the best model for the product first, then any model of the
        tenant.

        Args:
            tenant_id: Tenant whose models are searched.
            product_name: Product the forecast is requested for.

        Returns:
            The model description dict, or ``None`` when no model exists or
            the lookup itself failed (the failure is logged, not raised).
        """
        try:
            # Primary: Try to get the best model for this specific product
            model_data = await self.model_client.get_best_model_for_forecasting(
                tenant_id=tenant_id,
                product_name=product_name
            )

            if model_data:
                logger.info("Found specific model for product",
                            product=product_name,
                            model_id=model_data.get('model_id'))
                return model_data

            # Fallback 1: Try to get any model for this tenant
            logger.warning("No specific model found, trying fallback", product=product_name)
            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)

            if fallback_model:
                logger.info("Using fallback model",
                            model_id=fallback_model.get('model_id'))
                return fallback_model

            # Fallback 2: Could trigger retraining here
            logger.error("No models available for tenant", tenant_id=tenant_id)
            return None

        except Exception as e:
            # Deliberate best-effort: the caller turns None into ValueError.
            logger.error("Error getting model", error=str(e))
            return None

    async def _prepare_forecast_features_with_fallbacks(
        self,
        tenant_id: str,
        request: ForecastRequest
    ) -> Dict[str, Any]:
        """Prepare features with comprehensive fallbacks for missing data.

        Args:
            tenant_id: Tenant used for the weather lookup and for logging.
            request: Supplies ``forecast_date``, from which all calendar
                features are derived.

        Returns:
            Feature dict holding calendar, season, holiday, weather and
            traffic entries.
        """
        forecast_date = request.forecast_date
        weekday = forecast_date.weekday()

        features = {
            "date": forecast_date.isoformat(),
            "day_of_week": weekday,
            "is_weekend": weekday >= 5,
            "day_of_month": forecast_date.day,
            "month": forecast_date.month,
            "quarter": (forecast_date.month - 1) // 3 + 1,
            "week_of_year": forecast_date.isocalendar().week,
        }

        # Season feature added to match the training service.
        features["season"] = self._get_season(forecast_date.month)

        # Add Spanish holidays
        features["is_holiday"] = self._is_spanish_holiday(forecast_date)

        # Weather and traffic helpers mutate `features` in place.
        await self._add_weather_features_with_fallbacks(features, tenant_id)
        await self._add_traffic_features_with_fallbacks(features, tenant_id)

        return features

    def _extract_weather_features(self, weather: Dict[str, Any]) -> Dict[str, Any]:
        """Map one weather payload to feature entries, defaulting missing keys."""
        return {
            "temperature": weather.get("temperature", 20.0),
            "precipitation": weather.get("precipitation", 0.0),
            "humidity": weather.get("humidity", 65.0),
            "wind_speed": weather.get("wind_speed", 5.0),
            "pressure": weather.get("pressure", self._DEFAULT_PRESSURE),
        }

    async def _add_weather_features_with_fallbacks(
        self,
        features: Dict[str, Any],
        tenant_id: str
    ) -> None:
        """Add weather features with multiple fallback strategies.

        Order: weather forecast, then current weather, then the seasonal
        defaults for the forecast month. Failures of the first two sources
        are logged and never raised.

        Args:
            features: Feature dict, updated in place.
            tenant_id: Tenant passed to the data client and to the logs.
        """
        try:
            weather_data = await self.data_client.fetch_weather_forecast(
                tenant_id=tenant_id,
                days=1,
                latitude=self._DEFAULT_LATITUDE,
                longitude=self._DEFAULT_LONGITUDE
            )

            if weather_data:
                # The response may be a list of entries or a single mapping.
                weather = weather_data[0] if isinstance(weather_data, list) else weather_data
                features.update(self._extract_weather_features(weather))

                logger.info("Weather data acquired successfully", tenant_id=tenant_id)
                return

        except Exception as e:
            logger.warning("Primary weather data acquisition failed", error=str(e))

        # Fallback 1: Try current weather instead of forecast
        try:
            current_weather = await self.data_client.get_current_weather(
                tenant_id=tenant_id,
                latitude=self._DEFAULT_LATITUDE,
                longitude=self._DEFAULT_LONGITUDE
            )

            if current_weather:
                features.update(self._extract_weather_features(current_weather))

                logger.info("Using current weather as fallback", tenant_id=tenant_id)
                return

        except Exception as e:
            logger.warning("Fallback weather data acquisition failed", error=str(e))

        # Fallback 2: seasonal averages for the month being forecast. The
        # wall-clock month is only used if the caller supplied no "month".
        month = features.get("month") or datetime.now().month
        seasonal_defaults = self._get_seasonal_weather_defaults(month)
        features.update(seasonal_defaults)

        logger.warning("Using seasonal weather defaults",
                       tenant_id=tenant_id,
                       defaults=seasonal_defaults)

    async def _add_traffic_features_with_fallbacks(
        self,
        features: Dict[str, Any],
        tenant_id: str
    ) -> None:
        """Add traffic features with fallbacks.

        No live traffic source is queried; fixed baseline values are used,
        scaled down on weekends.

        Args:
            features: Feature dict, updated in place. Must already contain
                ``"is_weekend"``.
            tenant_id: Tenant identifier, used for logging only.
        """
        # TODO: query the data client for live traffic data once that
        # integration is enabled; until then only the defaults below apply.
        weekend_factor = 0.7 if features["is_weekend"] else 1.0

        features.update({
            "traffic_volume": int(100 * weekend_factor),
            "pedestrian_count": int(50 * weekend_factor),
            "congestion_level": 1
        })

        logger.warning("Using default traffic values", tenant_id=tenant_id)

    def _get_seasonal_weather_defaults(self, month: int) -> Dict[str, float]:
        """Get seasonal weather defaults for Madrid.

        Args:
            month: Calendar month (1-12). Any other value falls back to the
                April averages.

        Returns:
            A new dict with ``temperature``, ``precipitation``, ``humidity``,
            ``wind_speed`` and ``pressure``, i.e. the same keys the live
            weather paths produce.
        """
        averages = self._SEASONAL_WEATHER.get(month, self._SEASONAL_WEATHER[4])
        # Copy so callers never mutate the shared table, and always include
        # pressure so the feature set does not depend on the weather source.
        return {**averages, "pressure": self._DEFAULT_PRESSURE}

    def _get_season(self, month: int) -> int:
        """Get season from month (1-4 for Winter, Spring, Summer, Autumn) - MATCH TRAINING"""
        if month in (12, 1, 2):
            return 1  # Winter
        if month in (3, 4, 5):
            return 2  # Spring
        if month in (6, 7, 8):
            return 3  # Summer
        return 4  # Autumn

    def _is_spanish_holiday(self, date: datetime) -> bool:
        """Check if a date is a major Spanish holiday.

        Only fixed-date holidays are considered; the year is ignored.

        Args:
            date: Any object exposing ``month`` and ``day``.

        Returns:
            ``True`` when the (month, day) pair is in the holiday set.
        """
        return (date.month, date.day) in self._SPANISH_HOLIDAYS

    def _apply_business_rules(
        self,
        prediction: Dict[str, float],
        request: ForecastRequest,
        features: Dict[str, Any]
    ) -> Dict[str, float]:
        """Apply Spanish bakery business rules to predictions.

        A single multiplicative factor is built from weekend, holiday, rain
        and temperature rules and applied to the point prediction and both
        bounds. Results are clamped at zero.

        Args:
            prediction: Raw model output with ``prediction``,
                ``lower_bound``, ``upper_bound`` and ``confidence_level``.
            request: Forecast request (not used by the current rules).
            features: Feature dict; ``is_weekend``, ``is_holiday``,
                ``temperature`` and ``precipitation`` are read with defaults.

        Returns:
            Dict with the adjusted ``prediction``, ``lower_bound`` and
            ``upper_bound``, the ``confidence_interval`` width, the original
            ``confidence_level`` and the ``adjustment_factor`` applied.
        """
        adjustment_factor = 1.0

        # Weekend adjustment
        if features.get("is_weekend", False):
            adjustment_factor *= 0.8  # 20% reduction on weekends

        # Holiday adjustment
        if features.get("is_holiday", False):
            adjustment_factor *= 0.5  # 50% reduction on holidays

        temperature = features.get("temperature", 20.0)
        precipitation = features.get("precipitation", 0.0)

        # Rain impact (people stay home)
        if precipitation > 2.0:
            adjustment_factor *= 0.7  # 30% reduction in heavy rain
        elif precipitation > 0.1:
            adjustment_factor *= 0.9  # 10% reduction in light rain

        # Temperature impact
        if temperature < 5 or temperature > 35:
            adjustment_factor *= 0.8  # Extreme temperatures reduce foot traffic
        elif 18 <= temperature <= 25:
            adjustment_factor *= 1.1  # Pleasant weather increases activity

        # Demand cannot be negative, so every scaled value is clamped at 0.
        adjusted_prediction = max(0, prediction["prediction"] * adjustment_factor)
        adjusted_lower = max(0, prediction["lower_bound"] * adjustment_factor)
        adjusted_upper = max(0, prediction["upper_bound"] * adjustment_factor)

        return {
            "prediction": adjusted_prediction,
            "lower_bound": adjusted_lower,
            "upper_bound": adjusted_upper,
            "confidence_interval": adjusted_upper - adjusted_lower,
            "confidence_level": prediction["confidence_level"],
            "adjustment_factor": adjustment_factor
        }

    async def _save_forecast(
        self,
        db: AsyncSession,
        tenant_id: str,
        request: ForecastRequest,
        prediction: Dict[str, float],
        model_data: Dict[str, Any],
        features: Dict[str, Any],
        start_time: Optional[datetime] = None
    ) -> Forecast:
        """Save forecast to database.

        Args:
            db: Session used to add, commit and refresh the row.
            tenant_id: Tenant the forecast belongs to.
            request: Supplies product, location, date and confidence level.
            prediction: Adjusted prediction with ``prediction``,
                ``lower_bound`` and ``upper_bound``.
            model_data: Model description; ``model_id`` is required,
                ``version`` and ``algorithm`` are optional.
            features: Feature dict; stored whole in ``features_used`` and
                also read for the context and external-factor columns.
            start_time: When forecast generation began, used to compute
                ``processing_time_ms``. When omitted, the duration is
                measured from this call and is therefore close to zero.

        Returns:
            The refreshed ``Forecast`` row.
        """
        # Previously this method read an undefined `start_time` name and
        # raised NameError; the value now arrives as an explicit parameter.
        started_at = start_time if start_time is not None else datetime.now()
        elapsed_ms = int((datetime.now() - started_at).total_seconds() * 1000)

        forecast = Forecast(
            tenant_id=tenant_id,
            product_name=request.product_name,
            location=request.location,
            forecast_date=request.forecast_date,

            # Predictions
            predicted_demand=prediction['prediction'],
            confidence_lower=prediction['lower_bound'],
            confidence_upper=prediction['upper_bound'],
            confidence_level=request.confidence_level,

            # Model info
            model_id=model_data['model_id'],
            model_version=model_data.get('version', '1.0'),
            algorithm=model_data.get('algorithm', 'prophet'),

            # Context
            business_type=features.get('business_type', 'individual'),
            is_holiday=features.get('is_holiday', False),
            is_weekend=features.get('is_weekend', False),
            day_of_week=features.get('day_of_week', 0),

            # External factors
            weather_temperature=features.get('temperature'),
            weather_precipitation=features.get('precipitation'),
            weather_description=features.get('weather_description'),
            traffic_volume=features.get('traffic_volume'),

            # Metadata
            processing_time_ms=elapsed_ms,
            features_used=features
        )

        db.add(forecast)
        await db.commit()
        # Refresh after the commit so the returned object is reloaded from
        # the database before the caller reads its attributes.
        await db.refresh(forecast)

        return forecast

    async def get_forecast_history(
        self,
        tenant_id: str,
        product_name: Optional[str] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        db: AsyncSession = None
    ) -> List[Forecast]:
        """Retrieve forecast history with filters.

        Args:
            tenant_id: Tenant whose forecasts are returned.
            product_name: Optional exact product filter.
            start_date: Optional inclusive lower bound on ``forecast_date``.
            end_date: Optional inclusive upper bound on ``forecast_date``.
            db: Session to query with. Despite the ``None`` default it must
                be supplied; the query cannot run without it.

        Returns:
            Matching forecasts, newest ``forecast_date`` first.

        Raises:
            Exception: Any query failure is logged and re-raised unchanged.
        """
        try:
            query = select(Forecast).where(Forecast.tenant_id == tenant_id)

            if product_name:
                query = query.where(Forecast.product_name == product_name)

            if start_date:
                query = query.where(Forecast.forecast_date >= start_date)

            if end_date:
                query = query.where(Forecast.forecast_date <= end_date)

            query = query.order_by(desc(Forecast.forecast_date))

            result = await db.execute(query)
            forecasts = result.scalars().all()

            logger.info("Retrieved forecasts",
                        tenant_id=tenant_id,
                        count=len(forecasts))

            return list(forecasts)

        except Exception as e:
            logger.error("Error retrieving forecasts", error=str(e))
            raise