Files
bakery-ia/services/forecasting/app/services/forecasting_service.py
2025-07-30 08:29:40 +02:00

500 lines
20 KiB
Python

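# services/forecasting/app/services/forecasting_service.py
"""
Forecasting service: generates demand forecasts for a tenant's products.

Resolves a trained model (product-specific, falling back to any model of the
tenant), builds calendar, weather and traffic features (with fallbacks when
external data is unavailable), runs the prediction, applies bakery
business-rule adjustments and persists the resulting forecast.
Collaborating clients are constructed directly in ForecastingService.__init__.
"""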
import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from app.models.forecasts import Forecast
from app.schemas.forecasts import ForecastRequest, ForecastResponse
from app.services.prediction_service import PredictionService
from app.core.config import settings
from app.services.model_client import ModelClient
from app.services.data_client import DataClient
# Module-level structlog logger used by ForecastingService for all log events.
logger = structlog.get_logger()
class ForecastingService:
    """Generate, adjust and persist demand forecasts for a tenant's products.

    Collaborators are constructed in ``__init__``:

    * ``PredictionService`` - runs a trained model on a feature dict.
    * ``ModelClient`` - selects which trained model to use.
    * ``DataClient`` - fetches external weather data used as features.
    """

    # Madrid coordinates used for every external data lookup.
    # NOTE(review): hard-coded for all tenants - presumably should eventually
    # come from the tenant's own location.
    DEFAULT_LATITUDE = 40.4168
    DEFAULT_LONGITUDE = -3.7038

    # Fixed-date Spanish holidays as (month, day). Movable feasts (e.g. Good
    # Friday) and regional holidays are not covered.
    SPANISH_HOLIDAYS = frozenset({
        (1, 1),    # New Year
        (1, 6),    # Epiphany (Reyes)
        (5, 1),    # Labour Day
        (8, 15),   # Assumption
        (10, 12),  # National Day
        (11, 1),   # All Saints
        (12, 6),   # Constitution Day
        (12, 8),   # Immaculate Conception
        (12, 25),  # Christmas
    })

    def __init__(self):
        self.prediction_service = PredictionService()
        self.model_client = ModelClient()
        self.data_client = DataClient()

    async def generate_forecast(
        self,
        tenant_id: str,
        request: ForecastRequest,
        db: AsyncSession
    ) -> ForecastResponse:
        """Generate, persist and return a forecast for one product and date.

        Args:
            tenant_id: Tenant that owns the product and its trained models.
            request: Product, date, location and confidence level to forecast.
            db: Async session used to store the resulting ``Forecast`` row.

        Returns:
            A ``ForecastResponse`` built from the persisted ``Forecast``.

        Raises:
            ValueError: If no model is available for the tenant/product.
            Exception: Any error from feature preparation, prediction or
                persistence is logged and re-raised unchanged.
        """
        # Captured here and handed to _save_forecast so the stored processing
        # time covers the whole pipeline, not just the database write.
        start_time = datetime.now()

        try:
            logger.info("Generating forecast",
                        date=request.forecast_date,
                        product=request.product_name,
                        tenant_id=tenant_id)

            # Step 1: resolve a model (product-specific first, then tenant-wide).
            model_data = await self._get_latest_model_with_fallback(tenant_id, request.product_name)
            if not model_data:
                raise ValueError(f"No valid model available for product: {request.product_name}")

            # MAPE is an error metric (lower is better), so a missing/zero value
            # means "not reported", not "low accuracy". The model is still used
            # because there is no alternative.
            if not model_data.get('mape'):
                logger.warning("Model accuracy (MAPE) missing or zero; using model anyway",
                               tenant_id=tenant_id,
                               model_id=model_data.get('model_id'))

            # Step 2: build calendar / weather / traffic features.
            features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

            # Step 3: run the model.
            prediction_result = await self.prediction_service.predict(
                model_id=model_data['model_id'],
                model_path=model_data['model_path'],
                features=features,
                confidence_level=request.confidence_level
            )

            # Step 4: apply business-rule multipliers.
            adjusted_prediction = self._apply_business_rules(
                prediction_result,
                request,
                features
            )

            # Step 5: persist the forecast.
            forecast = await self._save_forecast(
                db=db,
                tenant_id=tenant_id,
                request=request,
                prediction=adjusted_prediction,
                model_data=model_data,
                features=features,
                start_time=start_time
            )

            logger.info("Forecast generated successfully",
                        forecast_id=forecast.id,
                        prediction=adjusted_prediction['prediction'])

            return ForecastResponse(
                id=str(forecast.id),
                tenant_id=str(forecast.tenant_id),
                product_name=forecast.product_name,
                location=forecast.location,
                forecast_date=forecast.forecast_date,

                # Predictions
                predicted_demand=forecast.predicted_demand,
                confidence_lower=forecast.confidence_lower,
                confidence_upper=forecast.confidence_upper,
                confidence_level=forecast.confidence_level,

                # Model info
                model_id=str(forecast.model_id),
                model_version=forecast.model_version,
                algorithm=forecast.algorithm,

                # Context
                business_type=forecast.business_type,
                is_holiday=forecast.is_holiday,
                is_weekend=forecast.is_weekend,
                day_of_week=forecast.day_of_week,

                # External factors
                weather_temperature=forecast.weather_temperature,
                weather_precipitation=forecast.weather_precipitation,
                weather_description=forecast.weather_description,
                traffic_volume=forecast.traffic_volume,

                # Metadata
                created_at=forecast.created_at,
                processing_time_ms=forecast.processing_time_ms,
                features_used=forecast.features_used
            )

        except Exception as e:
            logger.error("Error generating forecast",
                         error=str(e),
                         product=request.product_name,
                         tenant_id=tenant_id)
            raise

    async def _get_latest_model_with_fallback(
        self,
        tenant_id: str,
        product_name: str
    ) -> Optional[Dict[str, Any]]:
        """Return model metadata for the product, else any tenant model, else None.

        Errors from the model client are logged and turned into ``None``, so
        the caller reports "no valid model" rather than the underlying error.
        """
        try:
            # Primary: best model for this specific product.
            model_data = await self.model_client.get_best_model_for_forecasting(
                tenant_id=tenant_id,
                product_name=product_name
            )
            if model_data:
                logger.info("Found specific model for product",
                            product=product_name,
                            model_id=model_data.get('model_id'))
                return model_data

            # Fallback 1: any model belonging to the tenant.
            logger.warning("No specific model found, trying fallback", product=product_name)
            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)
            if fallback_model:
                logger.info("Using fallback model",
                            model_id=fallback_model.get('model_id'))
                return fallback_model

            # Fallback 2 (not implemented): could trigger retraining here.
            logger.error("No models available for tenant", tenant_id=tenant_id)
            return None

        except Exception as e:
            logger.error("Error getting model",
                         error=str(e),
                         tenant_id=tenant_id,
                         product=product_name)
            return None

    async def _prepare_forecast_features_with_fallbacks(
        self,
        tenant_id: str,
        request: ForecastRequest
    ) -> Dict[str, Any]:
        """Build the model feature dict for ``request.forecast_date``.

        Calendar features are derived from the forecast date; weather and
        traffic features are added in place by helpers that fall back to
        defaults when external data is unavailable.
        """
        forecast_date = request.forecast_date
        features = {
            "date": forecast_date.isoformat(),
            "day_of_week": forecast_date.weekday(),
            "is_weekend": forecast_date.weekday() >= 5,  # Saturday=5, Sunday=6
            "day_of_month": forecast_date.day,
            "month": forecast_date.month,
            "quarter": (forecast_date.month - 1) // 3 + 1,
            "week_of_year": forecast_date.isocalendar().week,
        }

        # Season encoding must match the one used by the training service.
        features["season"] = self._get_season(forecast_date.month)
        features["is_holiday"] = self._is_spanish_holiday(forecast_date)

        await self._add_weather_features_with_fallbacks(features, tenant_id)
        await self._add_traffic_features_with_fallbacks(features, tenant_id)

        return features

    @staticmethod
    def _extract_weather_features(weather: Dict[str, Any]) -> Dict[str, Any]:
        """Map a weather payload onto model feature keys, filling missing values."""
        return {
            "temperature": weather.get("temperature", 20.0),
            "precipitation": weather.get("precipitation", 0.0),
            "humidity": weather.get("humidity", 65.0),
            "wind_speed": weather.get("wind_speed", 5.0),
            "pressure": weather.get("pressure", 1013.0),
        }

    async def _add_weather_features_with_fallbacks(
        self,
        features: Dict[str, Any],
        tenant_id: str
    ) -> None:
        """Add weather features to ``features`` in place.

        Tries, in order: the weather forecast, the current weather, and
        finally monthly averages for the forecast month. Every path sets the
        same keys (temperature, precipitation, humidity, wind_speed, pressure).
        """
        # NOTE(review): always requests a 1-day forecast and passes no date, so
        # for forecast dates further out this is near-term weather - confirm
        # against the DataClient API.
        try:
            weather_data = await self.data_client.fetch_weather_forecast(
                tenant_id=tenant_id,
                days=1,
                latitude=self.DEFAULT_LATITUDE,
                longitude=self.DEFAULT_LONGITUDE
            )
            if weather_data:
                # The client may return a list of entries or a single dict.
                weather = weather_data[0] if isinstance(weather_data, list) else weather_data
                features.update(self._extract_weather_features(weather))
                logger.info("Weather data acquired successfully", tenant_id=tenant_id)
                return
        except Exception as e:
            logger.warning("Primary weather data acquisition failed", error=str(e))

        # Fallback 1: current weather instead of the forecast.
        try:
            current_weather = await self.data_client.get_current_weather(
                tenant_id=tenant_id,
                latitude=self.DEFAULT_LATITUDE,
                longitude=self.DEFAULT_LONGITUDE
            )
            if current_weather:
                features.update(self._extract_weather_features(current_weather))
                logger.info("Using current weather as fallback", tenant_id=tenant_id)
                return
        except Exception as e:
            logger.warning("Fallback weather data acquisition failed", error=str(e))

        # Fallback 2: monthly averages for the *forecast* month (the forecast
        # date may be in a different month than today). Pressure is added so the
        # feature set matches the other two paths.
        month = features.get("month", datetime.now().month)
        seasonal_defaults = self._get_seasonal_weather_defaults(month)
        features.update({"pressure": 1013.0, **seasonal_defaults})
        logger.warning("Using seasonal weather defaults",
                       tenant_id=tenant_id,
                       defaults=seasonal_defaults)

    async def _add_traffic_features_with_fallbacks(
        self,
        features: Dict[str, Any],
        tenant_id: str
    ) -> None:
        """Add traffic features to ``features`` in place.

        Expects ``features["is_weekend"]`` to be set. Only default values are
        produced: a live traffic lookup via DataClient is currently disabled.
        TODO: re-enable the live lookup (previously ``get_traffic_data``) once
        the data service provides it.
        """
        # Weekends are assumed to carry 30% less traffic than weekdays.
        weekend_factor = 0.7 if features["is_weekend"] else 1.0
        features.update({
            "traffic_volume": int(100 * weekend_factor),
            "pedestrian_count": int(50 * weekend_factor),
            "congestion_level": 1
        })
        logger.warning("Using default traffic values", tenant_id=tenant_id)

    def _get_seasonal_weather_defaults(self, month: int) -> Dict[str, float]:
        """Return approximate Madrid monthly weather averages for ``month``.

        Unknown month values fall back to April's averages.
        """
        seasonal_data = {
            # Winter (Dec, Jan, Feb)
            12: {"temperature": 9.0, "precipitation": 2.0, "humidity": 70.0, "wind_speed": 8.0},
            1: {"temperature": 8.0, "precipitation": 2.5, "humidity": 72.0, "wind_speed": 7.0},
            2: {"temperature": 11.0, "precipitation": 2.0, "humidity": 68.0, "wind_speed": 8.0},
            # Spring (Mar, Apr, May)
            3: {"temperature": 15.0, "precipitation": 1.5, "humidity": 65.0, "wind_speed": 9.0},
            4: {"temperature": 18.0, "precipitation": 2.0, "humidity": 62.0, "wind_speed": 8.0},
            5: {"temperature": 23.0, "precipitation": 1.8, "humidity": 58.0, "wind_speed": 7.0},
            # Summer (Jun, Jul, Aug)
            6: {"temperature": 29.0, "precipitation": 0.5, "humidity": 50.0, "wind_speed": 6.0},
            7: {"temperature": 33.0, "precipitation": 0.2, "humidity": 45.0, "wind_speed": 5.0},
            8: {"temperature": 32.0, "precipitation": 0.3, "humidity": 47.0, "wind_speed": 5.0},
            # Autumn (Sep, Oct, Nov)
            9: {"temperature": 26.0, "precipitation": 1.0, "humidity": 55.0, "wind_speed": 6.0},
            10: {"temperature": 19.0, "precipitation": 2.5, "humidity": 65.0, "wind_speed": 7.0},
            11: {"temperature": 13.0, "precipitation": 2.8, "humidity": 70.0, "wind_speed": 8.0},
        }
        return seasonal_data.get(month, seasonal_data[4])

    def _get_season(self, month: int) -> int:
        """Encode ``month`` as 1=Winter, 2=Spring, 3=Summer, 4=Autumn.

        Must stay in sync with the training service's season encoding.
        """
        if month in (12, 1, 2):
            return 1  # Winter
        if month in (3, 4, 5):
            return 2  # Spring
        if month in (6, 7, 8):
            return 3  # Summer
        return 4      # Autumn (and any other value)

    def _is_spanish_holiday(self, date: date) -> bool:
        """Return True if ``date`` falls on one of ``SPANISH_HOLIDAYS``.

        Accepts ``date`` or ``datetime``; only month and day are compared.
        """
        return (date.month, date.day) in self.SPANISH_HOLIDAYS

    def _apply_business_rules(
        self,
        prediction: Dict[str, float],
        request: ForecastRequest,
        features: Dict[str, Any]
    ) -> Dict[str, float]:
        """Scale the model output by calendar and weather multipliers.

        Args:
            prediction: Model output; must contain ``prediction``,
                ``lower_bound``, ``upper_bound`` and ``confidence_level``.
            request: The forecast request (currently unused by the rules).
            features: Feature dict from ``_prepare_forecast_features_with_fallbacks``.

        Returns:
            Adjusted prediction and bounds (clipped at 0), the resulting
            interval width, the original confidence level and the combined
            ``adjustment_factor``.
        """
        base_prediction = prediction["prediction"]
        lower_bound = prediction["lower_bound"]
        upper_bound = prediction["upper_bound"]

        adjustment_factor = 1.0

        if features.get("is_weekend", False):
            adjustment_factor *= 0.8  # 20% reduction on weekends

        if features.get("is_holiday", False):
            adjustment_factor *= 0.5  # 50% reduction on holidays

        temperature = features.get("temperature", 20.0)
        precipitation = features.get("precipitation", 0.0)

        # Rain keeps customers at home.
        if precipitation > 2.0:
            adjustment_factor *= 0.7  # heavy rain
        elif precipitation > 0.1:
            adjustment_factor *= 0.9  # light rain

        # Extreme temperatures reduce foot traffic; pleasant weather increases it.
        if temperature < 5 or temperature > 35:
            adjustment_factor *= 0.8
        elif 18 <= temperature <= 25:
            adjustment_factor *= 1.1

        adjusted_prediction = max(0, base_prediction * adjustment_factor)
        adjusted_lower = max(0, lower_bound * adjustment_factor)
        adjusted_upper = max(0, upper_bound * adjustment_factor)

        return {
            "prediction": adjusted_prediction,
            "lower_bound": adjusted_lower,
            "upper_bound": adjusted_upper,
            "confidence_interval": adjusted_upper - adjusted_lower,
            "confidence_level": prediction["confidence_level"],
            "adjustment_factor": adjustment_factor
        }

    async def _save_forecast(
        self,
        db: AsyncSession,
        tenant_id: str,
        request: ForecastRequest,
        prediction: Dict[str, float],
        model_data: Dict[str, Any],
        features: Dict[str, Any],
        start_time: Optional[datetime] = None
    ) -> Forecast:
        """Persist a ``Forecast`` row and return it refreshed from the database.

        Args:
            start_time: When the forecast pipeline started; used to compute
                ``processing_time_ms``. If omitted, only the time spent in this
                method before the row is built is measured.

        Raises:
            Exception: Any commit error is re-raised after rolling back the
                session so the caller can keep using it.
        """
        # Previously this read an undefined ``start_time`` (a local of
        # generate_forecast) and raised NameError on every save.
        started = start_time if start_time is not None else datetime.now()
        processing_time_ms = int((datetime.now() - started).total_seconds() * 1000)

        forecast = Forecast(
            tenant_id=tenant_id,
            product_name=request.product_name,
            location=request.location,
            forecast_date=request.forecast_date,

            # Predictions
            predicted_demand=prediction['prediction'],
            confidence_lower=prediction['lower_bound'],
            confidence_upper=prediction['upper_bound'],
            confidence_level=request.confidence_level,

            # Model info
            model_id=model_data['model_id'],
            model_version=model_data.get('version', '1.0'),
            algorithm=model_data.get('algorithm', 'prophet'),

            # Context
            business_type=features.get('business_type', 'individual'),
            is_holiday=features.get('is_holiday', False),
            is_weekend=features.get('is_weekend', False),
            day_of_week=features.get('day_of_week', 0),

            # External factors
            weather_temperature=features.get('temperature'),
            weather_precipitation=features.get('precipitation'),
            weather_description=features.get('weather_description'),
            traffic_volume=features.get('traffic_volume'),

            # Metadata
            processing_time_ms=processing_time_ms,
            features_used=features
        )

        db.add(forecast)
        try:
            await db.commit()
        except Exception:
            # A failed commit leaves the session in an unusable state until it
            # is rolled back.
            await db.rollback()
            raise
        await db.refresh(forecast)

        return forecast

    async def get_forecast_history(
        self,
        tenant_id: str,
        product_name: Optional[str] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        db: AsyncSession = None
    ) -> List[Forecast]:
        """Return a tenant's forecasts, newest forecast date first.

        Optional filters narrow by product and by an inclusive
        ``[start_date, end_date]`` range on ``forecast_date``.

        Raises:
            ValueError: If no database session is supplied.
            Exception: Query errors are logged and re-raised.
        """
        try:
            if db is None:
                raise ValueError("A database session (db) is required")

            query = select(Forecast).where(Forecast.tenant_id == tenant_id)

            if product_name:
                query = query.where(Forecast.product_name == product_name)
            if start_date:
                query = query.where(Forecast.forecast_date >= start_date)
            if end_date:
                query = query.where(Forecast.forecast_date <= end_date)

            query = query.order_by(desc(Forecast.forecast_date))

            result = await db.execute(query)
            forecasts = list(result.scalars().all())

            logger.info("Retrieved forecasts",
                        tenant_id=tenant_id,
                        count=len(forecasts))

            return forecasts

        except Exception as e:
            logger.error("Error retrieving forecasts", error=str(e))
            raise