Start fixing forecast service 15
This commit is contained in:
@@ -1,236 +1,442 @@
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/forecasting_service.py
|
||||
# ================================================================
|
||||
# services/forecasting/app/services/forecasting_service.py - FIXED INITIALIZATION
|
||||
"""
|
||||
Main forecasting service business logic
|
||||
Orchestrates demand prediction operations
|
||||
Enhanced forecasting service with proper ModelClient initialization
|
||||
FIXED: Correct initialization order and dependency injection
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime, date, timedelta
|
||||
import asyncio
|
||||
import uuid
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, desc
|
||||
import httpx
|
||||
|
||||
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||
from app.schemas.forecasts import ForecastRequest, BatchForecastRequest, BusinessType
|
||||
from app.models.forecasts import Forecast
|
||||
from app.schemas.forecasts import ForecastRequest, ForecastResponse
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.services.messaging import publish_forecast_completed, publish_alert_created
|
||||
from app.core.config import settings
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
from app.services.model_client import ModelClient
|
||||
from app.services.data_client import DataClient
|
||||
|
||||
logger = structlog.get_logger()
|
||||
metrics = MetricsCollector("forecasting-service")
|
||||
|
||||
class ForecastingService:
|
||||
"""
|
||||
Main service class for managing forecasting operations.
|
||||
Handles demand prediction, batch processing, and alert generation.
|
||||
"""
|
||||
"""Enhanced forecasting service with improved error handling"""
|
||||
|
||||
def __init__(self):
|
||||
self.prediction_service = PredictionService()
|
||||
self.model_client = ModelClient()
|
||||
self.data_client = DataClient()
|
||||
|
||||
async def generate_forecast(self, tenant_id: str, request: ForecastRequest, db: AsyncSession) -> Forecast:
|
||||
"""Generate a single forecast for a product"""
|
||||
start_time = datetime.now()
|
||||
async def generate_forecast(
|
||||
self,
|
||||
tenant_id: str,
|
||||
request: ForecastRequest,
|
||||
db: AsyncSession
|
||||
) -> ForecastResponse:
|
||||
"""Generate forecast with comprehensive error handling and fallbacks"""
|
||||
|
||||
try:
|
||||
logger.info("Generating forecast",
|
||||
tenant_id=tenant_id,
|
||||
date=request.forecast_date,
|
||||
product=request.product_name,
|
||||
date=request.forecast_date)
|
||||
tenant_id=tenant_id)
|
||||
|
||||
# Get the latest trained model for this tenant/product
|
||||
model_info = await self._get_latest_model(
|
||||
tenant_id,
|
||||
request.product_name,
|
||||
)
|
||||
# Step 1: Get model with validation
|
||||
model_data = await self._get_latest_model_with_fallback(tenant_id, request.product_name)
|
||||
|
||||
if not model_info:
|
||||
raise ValueError(f"No trained model found for {request.product_name}")
|
||||
if not model_data:
|
||||
raise ValueError(f"No valid model available for product: {request.product_name}")
|
||||
|
||||
# Prepare features for prediction
|
||||
features = await self._prepare_forecast_features(tenant_id, request)
|
||||
# Enhanced model accuracy check with fallback
|
||||
model_accuracy = model_data.get('mape', 0.0)
|
||||
if model_accuracy == 0.0:
|
||||
logger.warning("Model accuracy too low: 0.0", tenant_id=tenant_id)
|
||||
logger.info("Returning model despite low accuracy - no alternative available",
|
||||
tenant_id=tenant_id)
|
||||
# Continue with the model but log the issue
|
||||
|
||||
# Generate prediction using ML service
|
||||
# Step 2: Prepare features with fallbacks
|
||||
features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)
|
||||
|
||||
# Step 3: Generate prediction with the model
|
||||
prediction_result = await self.prediction_service.predict(
|
||||
model_id=model_info["model_id"],
|
||||
model_path=model_info["model_path"],
|
||||
model_id=model_data['model_id'],
|
||||
model_path=model_data['model_path'],
|
||||
features=features,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
# Create forecast record
|
||||
forecast = Forecast(
|
||||
tenant_id=uuid.UUID(tenant_id),
|
||||
product_name=request.product_name,
|
||||
forecast_date=datetime.combine(request.forecast_date, datetime.min.time()),
|
||||
|
||||
# Prediction results
|
||||
predicted_demand=prediction_result["demand"],
|
||||
confidence_lower=prediction_result["lower_bound"],
|
||||
confidence_upper=prediction_result["upper_bound"],
|
||||
confidence_level=request.confidence_level,
|
||||
|
||||
# Model information
|
||||
model_id=uuid.UUID(model_info["model_id"]),
|
||||
model_version=model_info["version"],
|
||||
algorithm=model_info.get("algorithm", "prophet"),
|
||||
|
||||
# Context
|
||||
business_type=request.business_type.value,
|
||||
day_of_week=request.forecast_date.weekday(),
|
||||
is_holiday=features.get("is_holiday", False),
|
||||
is_weekend=request.forecast_date.weekday() >= 5,
|
||||
|
||||
# External factors
|
||||
weather_temperature=features.get("temperature"),
|
||||
weather_precipitation=features.get("precipitation"),
|
||||
weather_description=features.get("weather_description"),
|
||||
traffic_volume=features.get("traffic_volume"),
|
||||
|
||||
# Metadata
|
||||
processing_time_ms=int((datetime.now() - start_time).total_seconds() * 1000),
|
||||
features_used=features
|
||||
# Step 4: Apply business rules and validation
|
||||
adjusted_prediction = self._apply_business_rules(
|
||||
prediction_result,
|
||||
request,
|
||||
features
|
||||
)
|
||||
|
||||
db.add(forecast)
|
||||
await db.commit()
|
||||
await db.refresh(forecast)
|
||||
|
||||
# Check for alerts
|
||||
await self._check_and_create_alerts(forecast, db)
|
||||
|
||||
# Update metrics
|
||||
metrics.increment_counter("forecasts_generated_total",
|
||||
{"product": request.product_name, "location": request.location})
|
||||
|
||||
# Publish event
|
||||
await publish_forecast_completed({
|
||||
"forecast_id": str(forecast.id),
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": request.product_name,
|
||||
"predicted_demand": forecast.predicted_demand
|
||||
})
|
||||
# Step 5: Save forecast to database
|
||||
forecast = await self._save_forecast(
|
||||
db=db,
|
||||
tenant_id=tenant_id,
|
||||
request=request,
|
||||
prediction=adjusted_prediction,
|
||||
model_data=model_data,
|
||||
features=features
|
||||
)
|
||||
|
||||
logger.info("Forecast generated successfully",
|
||||
forecast_id=str(forecast.id),
|
||||
predicted_demand=forecast.predicted_demand)
|
||||
forecast_id=forecast.id,
|
||||
prediction=adjusted_prediction['prediction'])
|
||||
|
||||
return forecast
|
||||
return ForecastResponse(
|
||||
id=forecast.id,
|
||||
forecast_date=forecast.forecast_date,
|
||||
product_name=forecast.product_name,
|
||||
predicted_quantity=forecast.predicted_quantity,
|
||||
confidence_level=forecast.confidence_level,
|
||||
lower_bound=forecast.lower_bound,
|
||||
upper_bound=forecast.upper_bound,
|
||||
model_id=forecast.model_id,
|
||||
created_at=forecast.created_at,
|
||||
external_factors=forecast.external_factors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error generating forecast",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
product=request.product_name)
|
||||
product=request.product_name,
|
||||
tenant_id=tenant_id)
|
||||
raise
|
||||
|
||||
async def generate_batch_forecast(self, request: BatchForecastRequest, db: AsyncSession) -> PredictionBatch:
|
||||
"""Generate forecasts for multiple products over multiple days"""
|
||||
|
||||
async def _get_latest_model_with_fallback(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get the latest trained model with fallback strategies"""
|
||||
try:
|
||||
logger.info("Starting batch forecast generation",
|
||||
tenant_id=request.tenant_id,
|
||||
batch_name=request.batch_name,
|
||||
products_count=len(request.products),
|
||||
forecast_days=request.forecast_days)
|
||||
|
||||
# Create batch record
|
||||
batch = PredictionBatch(
|
||||
tenant_id=uuid.UUID(request.tenant_id),
|
||||
batch_name=request.batch_name,
|
||||
status="processing",
|
||||
total_products=len(request.products) * request.forecast_days,
|
||||
business_type=request.business_type.value,
|
||||
forecast_days=request.forecast_days
|
||||
# Primary: Try to get the best model for this specific product
|
||||
model_data = await self.model_client.get_best_model_for_forecasting(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name
|
||||
)
|
||||
|
||||
db.add(batch)
|
||||
await db.commit()
|
||||
await db.refresh(batch)
|
||||
if model_data:
|
||||
logger.info("Found specific model for product",
|
||||
product=product_name,
|
||||
model_id=model_data.get('model_id'))
|
||||
return model_data
|
||||
|
||||
# Generate forecasts for each product and day
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
# Fallback 1: Try to get any model for this tenant
|
||||
logger.warning("No specific model found, trying fallback", product=product_name)
|
||||
fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)
|
||||
|
||||
for product in request.products:
|
||||
for day_offset in range(request.forecast_days):
|
||||
forecast_date = date.today() + timedelta(days=day_offset + 1)
|
||||
|
||||
try:
|
||||
forecast_request = ForecastRequest(
|
||||
tenant_id=request.tenant_id,
|
||||
product_name=product,
|
||||
location=request.location,
|
||||
forecast_date=forecast_date,
|
||||
business_type=request.business_type,
|
||||
include_weather=request.include_weather,
|
||||
include_traffic=request.include_traffic,
|
||||
confidence_level=request.confidence_level
|
||||
)
|
||||
|
||||
await self.generate_forecast(forecast_request, db)
|
||||
completed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate forecast for product",
|
||||
product=product,
|
||||
date=forecast_date,
|
||||
error=str(e))
|
||||
failed_count += 1
|
||||
if fallback_model:
|
||||
logger.info("Using fallback model",
|
||||
model_id=fallback_model.get('model_id'))
|
||||
return fallback_model
|
||||
|
||||
# Update batch status
|
||||
batch.status = "completed" if failed_count == 0 else "partial"
|
||||
batch.completed_products = completed_count
|
||||
batch.failed_products = failed_count
|
||||
batch.completed_at = datetime.now()
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Batch forecast generation completed",
|
||||
batch_id=str(batch.id),
|
||||
completed=completed_count,
|
||||
failed=failed_count)
|
||||
|
||||
return batch
|
||||
# Fallback 2: Could trigger retraining here
|
||||
logger.error("No models available for tenant", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in batch forecast generation", error=str(e))
|
||||
raise
|
||||
logger.error("Error getting model", error=str(e))
|
||||
return None
|
||||
|
||||
async def get_forecasts(self, tenant_id: str, location: str,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
product_name: Optional[str] = None,
|
||||
db: AsyncSession = None) -> List[Forecast]:
|
||||
"""Retrieve forecasts with filtering"""
|
||||
async def _prepare_forecast_features_with_fallbacks(
|
||||
self,
|
||||
tenant_id: str,
|
||||
request: ForecastRequest
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare features with comprehensive fallbacks for missing data"""
|
||||
|
||||
features = {
|
||||
"date": request.forecast_date.isoformat(),
|
||||
"day_of_week": request.forecast_date.weekday(),
|
||||
"is_weekend": request.forecast_date.weekday() >= 5,
|
||||
"day_of_month": request.forecast_date.day,
|
||||
"month": request.forecast_date.month,
|
||||
"quarter": (request.forecast_date.month - 1) // 3 + 1,
|
||||
"week_of_year": request.forecast_date.isocalendar().week,
|
||||
}
|
||||
|
||||
# ✅ FIX: Add season feature to match training service
|
||||
features["season"] = self._get_season(request.forecast_date.month)
|
||||
|
||||
# Add Spanish holidays
|
||||
features["is_holiday"] = self._is_spanish_holiday(request.forecast_date)
|
||||
|
||||
# Enhanced weather data acquisition with fallbacks
|
||||
await self._add_weather_features_with_fallbacks(features, tenant_id)
|
||||
|
||||
# Add traffic data with fallbacks
|
||||
# await self._add_traffic_features_with_fallbacks(features, tenant_id)
|
||||
|
||||
return features
|
||||
|
||||
async def _add_weather_features_with_fallbacks(
|
||||
self,
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str
|
||||
) -> None:
|
||||
"""Add weather features with multiple fallback strategies"""
|
||||
|
||||
try:
|
||||
query = select(Forecast).where(
|
||||
and_(
|
||||
Forecast.tenant_id == uuid.UUID(tenant_id),
|
||||
Forecast.location == location
|
||||
)
|
||||
# ✅ FIX: Use the corrected weather forecast call
|
||||
weather_data = await self.data_client.fetch_weather_forecast(
|
||||
tenant_id=tenant_id,
|
||||
days=1,
|
||||
latitude=40.4168, # Madrid coordinates
|
||||
longitude=-3.7038
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.where(Forecast.forecast_date >= datetime.combine(start_date, datetime.min.time()))
|
||||
if weather_data and len(weather_data) > 0:
|
||||
# Extract weather features from the response
|
||||
weather = weather_data[0] if isinstance(weather_data, list) else weather_data
|
||||
|
||||
features.update({
|
||||
"temperature": weather.get("temperature", 20.0),
|
||||
"precipitation": weather.get("precipitation", 0.0),
|
||||
"humidity": weather.get("humidity", 65.0),
|
||||
"wind_speed": weather.get("wind_speed", 5.0),
|
||||
"pressure": weather.get("pressure", 1013.0),
|
||||
})
|
||||
|
||||
logger.info("Weather data acquired successfully", tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Primary weather data acquisition failed", error=str(e))
|
||||
|
||||
# Fallback 1: Try current weather instead of forecast
|
||||
try:
|
||||
current_weather = await self.data_client.get_current_weather(
|
||||
tenant_id=tenant_id,
|
||||
latitude=40.4168,
|
||||
longitude=-3.7038
|
||||
)
|
||||
|
||||
if end_date:
|
||||
query = query.where(Forecast.forecast_date <= datetime.combine(end_date, datetime.max.time()))
|
||||
if current_weather:
|
||||
features.update({
|
||||
"temperature": current_weather.get("temperature", 20.0),
|
||||
"precipitation": current_weather.get("precipitation", 0.0),
|
||||
"humidity": current_weather.get("humidity", 65.0),
|
||||
"wind_speed": current_weather.get("wind_speed", 5.0),
|
||||
"pressure": current_weather.get("pressure", 1013.0),
|
||||
})
|
||||
|
||||
logger.info("Using current weather as fallback", tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Fallback weather data acquisition failed", error=str(e))
|
||||
|
||||
# Fallback 2: Use seasonal averages for Madrid
|
||||
month = datetime.now().month
|
||||
seasonal_defaults = self._get_seasonal_weather_defaults(month)
|
||||
features.update(seasonal_defaults)
|
||||
|
||||
logger.warning("Using seasonal weather defaults",
|
||||
tenant_id=tenant_id,
|
||||
defaults=seasonal_defaults)
|
||||
|
||||
async def _add_traffic_features_with_fallbacks(
|
||||
self,
|
||||
features: Dict[str, Any],
|
||||
tenant_id: str
|
||||
) -> None:
|
||||
"""Add traffic features with fallbacks"""
|
||||
|
||||
try:
|
||||
traffic_data = await self.data_client.get_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
latitude=40.4168,
|
||||
longitude=-3.7038
|
||||
)
|
||||
|
||||
if traffic_data:
|
||||
features.update({
|
||||
"traffic_volume": traffic_data.get("traffic_volume", 100),
|
||||
"pedestrian_count": traffic_data.get("pedestrian_count", 50),
|
||||
})
|
||||
logger.info("Traffic data acquired successfully", tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Traffic data acquisition failed", error=str(e))
|
||||
|
||||
# Fallback: Use typical values based on day of week
|
||||
day_of_week = features["day_of_week"]
|
||||
weekend_factor = 0.7 if features["is_weekend"] else 1.0
|
||||
|
||||
features.update({
|
||||
"traffic_volume": int(100 * weekend_factor),
|
||||
"pedestrian_count": int(50 * weekend_factor),
|
||||
})
|
||||
|
||||
logger.warning("Using default traffic values", tenant_id=tenant_id)
|
||||
|
||||
def _get_seasonal_weather_defaults(self, month: int) -> Dict[str, float]:
|
||||
"""Get seasonal weather defaults for Madrid"""
|
||||
|
||||
# Madrid seasonal averages
|
||||
seasonal_data = {
|
||||
# Winter (Dec, Jan, Feb)
|
||||
12: {"temperature": 9.0, "precipitation": 2.0, "humidity": 70.0, "wind_speed": 8.0},
|
||||
1: {"temperature": 8.0, "precipitation": 2.5, "humidity": 72.0, "wind_speed": 7.0},
|
||||
2: {"temperature": 11.0, "precipitation": 2.0, "humidity": 68.0, "wind_speed": 8.0},
|
||||
# Spring (Mar, Apr, May)
|
||||
3: {"temperature": 15.0, "precipitation": 1.5, "humidity": 65.0, "wind_speed": 9.0},
|
||||
4: {"temperature": 18.0, "precipitation": 2.0, "humidity": 62.0, "wind_speed": 8.0},
|
||||
5: {"temperature": 23.0, "precipitation": 1.8, "humidity": 58.0, "wind_speed": 7.0},
|
||||
# Summer (Jun, Jul, Aug)
|
||||
6: {"temperature": 29.0, "precipitation": 0.5, "humidity": 50.0, "wind_speed": 6.0},
|
||||
7: {"temperature": 33.0, "precipitation": 0.2, "humidity": 45.0, "wind_speed": 5.0},
|
||||
8: {"temperature": 32.0, "precipitation": 0.3, "humidity": 47.0, "wind_speed": 5.0},
|
||||
# Autumn (Sep, Oct, Nov)
|
||||
9: {"temperature": 26.0, "precipitation": 1.0, "humidity": 55.0, "wind_speed": 6.0},
|
||||
10: {"temperature": 19.0, "precipitation": 2.5, "humidity": 65.0, "wind_speed": 7.0},
|
||||
11: {"temperature": 13.0, "precipitation": 2.8, "humidity": 70.0, "wind_speed": 8.0},
|
||||
}
|
||||
|
||||
return seasonal_data.get(month, seasonal_data[4]) # Default to April values
|
||||
|
||||
def _get_season(self, month: int) -> int:
|
||||
"""Get season from month (1-4 for Winter, Spring, Summer, Autumn) - MATCH TRAINING"""
|
||||
if month in [12, 1, 2]:
|
||||
return 1 # Winter
|
||||
elif month in [3, 4, 5]:
|
||||
return 2 # Spring
|
||||
elif month in [6, 7, 8]:
|
||||
return 3 # Summer
|
||||
else:
|
||||
return 4 # Autumn
|
||||
|
||||
def _is_spanish_holiday(self, date: datetime) -> bool:
|
||||
"""Check if a date is a major Spanish holiday"""
|
||||
month_day = (date.month, date.day)
|
||||
|
||||
# Major Spanish holidays that affect bakery sales
|
||||
spanish_holidays = [
|
||||
(1, 1), # New Year
|
||||
(1, 6), # Epiphany (Reyes)
|
||||
(5, 1), # Labour Day
|
||||
(8, 15), # Assumption
|
||||
(10, 12), # National Day
|
||||
(11, 1), # All Saints
|
||||
(12, 6), # Constitution Day
|
||||
(12, 8), # Immaculate Conception
|
||||
(12, 25), # Christmas
|
||||
]
|
||||
|
||||
return month_day in spanish_holidays
|
||||
|
||||
def _apply_business_rules(
|
||||
self,
|
||||
prediction: Dict[str, float],
|
||||
request: ForecastRequest,
|
||||
features: Dict[str, Any]
|
||||
) -> Dict[str, float]:
|
||||
"""Apply Spanish bakery business rules to predictions"""
|
||||
|
||||
base_prediction = prediction["prediction"]
|
||||
lower_bound = prediction["lower_bound"]
|
||||
upper_bound = prediction["upper_bound"]
|
||||
|
||||
# Apply adjustment factors
|
||||
adjustment_factor = 1.0
|
||||
|
||||
# Weekend adjustment
|
||||
if features.get("is_weekend", False):
|
||||
adjustment_factor *= 0.8 # 20% reduction on weekends
|
||||
|
||||
# Holiday adjustment
|
||||
if features.get("is_holiday", False):
|
||||
adjustment_factor *= 0.5 # 50% reduction on holidays
|
||||
|
||||
# Weather adjustments
|
||||
temperature = features.get("temperature", 20.0)
|
||||
precipitation = features.get("precipitation", 0.0)
|
||||
|
||||
# Rain impact (people stay home)
|
||||
if precipitation > 2.0:
|
||||
adjustment_factor *= 0.7 # 30% reduction in heavy rain
|
||||
elif precipitation > 0.1:
|
||||
adjustment_factor *= 0.9 # 10% reduction in light rain
|
||||
|
||||
# Temperature impact
|
||||
if temperature < 5 or temperature > 35:
|
||||
adjustment_factor *= 0.8 # Extreme temperatures reduce foot traffic
|
||||
elif 18 <= temperature <= 25:
|
||||
adjustment_factor *= 1.1 # Pleasant weather increases activity
|
||||
|
||||
# Apply adjustments
|
||||
adjusted_prediction = max(0, base_prediction * adjustment_factor)
|
||||
adjusted_lower = max(0, lower_bound * adjustment_factor)
|
||||
adjusted_upper = max(0, upper_bound * adjustment_factor)
|
||||
|
||||
return {
|
||||
"prediction": adjusted_prediction,
|
||||
"lower_bound": adjusted_lower,
|
||||
"upper_bound": adjusted_upper,
|
||||
"confidence_interval": adjusted_upper - adjusted_lower,
|
||||
"confidence_level": prediction["confidence_level"],
|
||||
"adjustment_factor": adjustment_factor
|
||||
}
|
||||
|
||||
async def _save_forecast(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
tenant_id: str,
|
||||
request: ForecastRequest,
|
||||
prediction: Dict[str, float],
|
||||
model_data: Dict[str, Any],
|
||||
features: Dict[str, Any]
|
||||
) -> Forecast:
|
||||
"""Save forecast to database"""
|
||||
|
||||
forecast = Forecast(
|
||||
tenant_id=tenant_id,
|
||||
forecast_date=request.forecast_date,
|
||||
product_name=request.product_name,
|
||||
predicted_quantity=prediction["prediction"],
|
||||
confidence_level=request.confidence_level,
|
||||
lower_bound=prediction["lower_bound"],
|
||||
upper_bound=prediction["upper_bound"],
|
||||
model_id=model_data["model_id"],
|
||||
external_factors=features,
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
db.add(forecast)
|
||||
await db.commit()
|
||||
await db.refresh(forecast)
|
||||
|
||||
return forecast
|
||||
|
||||
async def get_forecast_history(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: Optional[str] = None,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
db: AsyncSession = None
|
||||
) -> List[Forecast]:
|
||||
"""Retrieve forecast history with filters"""
|
||||
|
||||
try:
|
||||
query = select(Forecast).where(Forecast.tenant_id == tenant_id)
|
||||
|
||||
if product_name:
|
||||
query = query.where(Forecast.product_name == product_name)
|
||||
|
||||
if start_date:
|
||||
query = query.where(Forecast.forecast_date >= start_date)
|
||||
|
||||
if end_date:
|
||||
query = query.where(Forecast.forecast_date <= end_date)
|
||||
|
||||
query = query.order_by(desc(Forecast.forecast_date))
|
||||
|
||||
result = await db.execute(query)
|
||||
@@ -244,129 +450,4 @@ class ForecastingService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving forecasts", error=str(e))
|
||||
raise
|
||||
|
||||
async def _get_latest_model(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the latest trained model for a tenant/product combination"""
|
||||
try:
|
||||
# Pass the product_name to the model client
|
||||
model_data = await self.model_client.get_best_model_for_forecasting(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name # Make sure to pass product_name
|
||||
)
|
||||
return model_data
|
||||
except Exception as e:
|
||||
logger.error("Error getting latest model", error=str(e))
|
||||
raise
|
||||
|
||||
async def _prepare_forecast_features(self, tenant_id: str, request: ForecastRequest) -> Dict[str, Any]:
|
||||
"""Prepare features for forecasting model"""
|
||||
|
||||
features = {
|
||||
"date": request.forecast_date.isoformat(),
|
||||
"day_of_week": request.forecast_date.weekday(),
|
||||
"is_weekend": request.forecast_date.weekday() >= 5
|
||||
}
|
||||
|
||||
# Add Spanish holidays
|
||||
features["is_holiday"] = self._is_spanish_holiday(request.forecast_date)
|
||||
|
||||
|
||||
weather_data = await self._get_weather_forecast(tenant_id, 1)
|
||||
features.update(weather_data)
|
||||
|
||||
return features
|
||||
|
||||
def _is_spanish_holiday(self, date: datetime) -> bool:
|
||||
"""Check if a date is a major Spanish holiday"""
|
||||
month_day = (date.month, date.day)
|
||||
|
||||
# Major Spanish holidays that affect bakery sales
|
||||
spanish_holidays = [
|
||||
(1, 1), # New Year
|
||||
(1, 6), # Epiphany (Reyes)
|
||||
(5, 1), # Labour Day
|
||||
(8, 15), # Assumption
|
||||
(10, 12), # National Day
|
||||
(11, 1), # All Saints
|
||||
(12, 6), # Constitution
|
||||
(12, 8), # Immaculate Conception
|
||||
(12, 25), # Christmas
|
||||
(5, 15), # San Isidro (Madrid patron saint)
|
||||
(5, 2), # Madrid Community Day
|
||||
]
|
||||
|
||||
return month_day in spanish_holidays
|
||||
|
||||
async def _get_weather_forecast(self, tenant_id: str, days: str) -> Dict[str, Any]:
|
||||
"""Get weather forecast for the date"""
|
||||
|
||||
try:
|
||||
weather_data = await self.data_client.fetch_weather_forecast(tenant_id, days)
|
||||
return weather_data
|
||||
except Exception as e:
|
||||
logger.warning("Error getting weather forecast", error=str(e))
|
||||
return {}
|
||||
|
||||
async def _check_and_create_alerts(self, forecast: Forecast, db: AsyncSession):
|
||||
"""Check forecast and create alerts if needed"""
|
||||
|
||||
try:
|
||||
alerts_to_create = []
|
||||
|
||||
# High demand alert
|
||||
if forecast.predicted_demand > settings.HIGH_DEMAND_THRESHOLD * 100: # Assuming base of 100 units
|
||||
alerts_to_create.append({
|
||||
"type": "high_demand",
|
||||
"severity": "medium",
|
||||
"message": f"High demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||
})
|
||||
|
||||
# Low demand alert
|
||||
if forecast.predicted_demand < settings.LOW_DEMAND_THRESHOLD * 100:
|
||||
alerts_to_create.append({
|
||||
"type": "low_demand",
|
||||
"severity": "low",
|
||||
"message": f"Low demand predicted for {forecast.product_name}: {forecast.predicted_demand:.0f} units"
|
||||
})
|
||||
|
||||
# Stockout risk alert
|
||||
if forecast.confidence_upper > settings.STOCKOUT_RISK_THRESHOLD * forecast.predicted_demand:
|
||||
alerts_to_create.append({
|
||||
"type": "stockout_risk",
|
||||
"severity": "high",
|
||||
"message": f"Stockout risk for {forecast.product_name}. Upper confidence: {forecast.confidence_upper:.0f}"
|
||||
})
|
||||
|
||||
# Create alerts
|
||||
for alert_data in alerts_to_create:
|
||||
alert = ForecastAlert(
|
||||
tenant_id=forecast.tenant_id,
|
||||
forecast_id=forecast.id,
|
||||
alert_type=alert_data["type"],
|
||||
severity=alert_data["severity"],
|
||||
message=alert_data["message"]
|
||||
)
|
||||
|
||||
db.add(alert)
|
||||
|
||||
# Publish alert event
|
||||
await publish_alert_created({
|
||||
"alert_id": str(alert.id),
|
||||
"tenant_id": str(forecast.tenant_id),
|
||||
"product_name": forecast.product_name,
|
||||
"alert_type": alert_data["type"],
|
||||
"severity": alert_data["severity"],
|
||||
"message": alert_data["message"]
|
||||
})
|
||||
|
||||
await db.commit()
|
||||
|
||||
if alerts_to_create:
|
||||
logger.info("Created forecast alerts",
|
||||
forecast_id=str(forecast.id),
|
||||
alerts_count=len(alerts_to_create))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating alerts", error=str(e))
|
||||
# Don't raise - alerts are not critical for forecast generation
|
||||
raise
|
||||
Reference in New Issue
Block a user