# Source: bakery-ia/services/forecasting/app/services/forecasting_service.py
# Snapshot: 2025-08-15 23:11:53 +02:00 (723 lines, 32 KiB, Python)

"""
Enhanced Forecasting Service with Repository Pattern
Main forecasting service that uses the repository pattern for data access
"""
import structlog
from typing import Dict, List, Any, Optional
from datetime import datetime, date, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from app.ml.predictor import BakeryForecaster
from app.schemas.forecasts import ForecastRequest, ForecastResponse
from app.services.prediction_service import PredictionService
from app.services.model_client import ModelClient
from app.services.data_client import DataClient
# Import repositories
from app.repositories import (
ForecastRepository,
PredictionBatchRepository,
ForecastAlertRepository,
PerformanceMetricRepository,
PredictionCacheRepository
)
# Import shared database components
from shared.database.base import create_database_manager
from shared.database.unit_of_work import UnitOfWork
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError
from app.core.config import settings
# Module-level structlog logger used by the service methods below; calls pass
# structured key/value context (tenant_id, error, ...) as keyword arguments.
logger = structlog.get_logger()
class EnhancedForecastingService:
    """
    Enhanced forecasting service using repository pattern.

    Handles forecast generation, batch processing, and alerting with proper
    data abstraction. Several public methods are still placeholders that
    return fixed values without touching the database; their docstrings say so.
    """

    # (month, day) pairs treated as major Spanish holidays by
    # _is_spanish_holiday. A frozenset gives O(1) membership tests and is
    # built once per class instead of on every call.
    _SPANISH_HOLIDAYS = frozenset({
        (1, 1), (1, 6), (5, 1), (8, 15), (10, 12),
        (11, 1), (12, 6), (12, 8), (12, 25),
    })

    def __init__(self, database_manager=None):
        """
        Create the service and its ML/data collaborators.

        Args:
            database_manager: Optional pre-built database manager. When omitted,
                one is created from ``settings.DATABASE_URL`` for the
                "forecasting-service" component.
        """
        self.database_manager = database_manager or create_database_manager(
            settings.DATABASE_URL, "forecasting-service"
        )

        # Initialize ML components; they share this service's database manager.
        self.forecaster = BakeryForecaster(database_manager=self.database_manager)
        self.prediction_service = PredictionService(database_manager=self.database_manager)
        self.model_client = ModelClient(database_manager=self.database_manager)
        self.data_client = DataClient()

    async def _init_repositories(self, session):
        """
        Build the repository set bound to ``session``.

        Returns:
            Dict keyed by 'forecast', 'batch', 'alert', 'performance', 'cache'.
        """
        return {
            'forecast': ForecastRepository(session),
            'batch': PredictionBatchRepository(session),
            'alert': ForecastAlertRepository(session),
            'performance': PerformanceMetricRepository(session),
            'cache': PredictionCacheRepository(session)
        }

    async def generate_batch_forecasts(self, tenant_id: str, request) -> Dict[str, Any]:
        """
        Generate batch forecasts using repository pattern.

        Placeholder: ``request`` is not used yet and an empty batch summary with
        a timestamp-based ``batch_id`` is returned.
        """
        try:
            # Implementation would use repository pattern to generate multiple forecasts
            return {
                "batch_id": f"batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                "tenant_id": tenant_id,
                "forecasts": [],
                "total_forecasts": 0,
                "successful_forecasts": 0,
                "failed_forecasts": 0,
                "enhanced_features": True,
                "repository_integration": True
            }
        except Exception as e:
            logger.error("Batch forecast generation failed", error=str(e))
            raise

    async def get_tenant_forecasts(self, tenant_id: str, inventory_product_id: str = None,
                                   start_date: date = None, end_date: date = None,
                                   skip: int = 0, limit: int = 100) -> List[Dict]:
        """
        Get tenant forecasts with optional product/date filtering and pagination.

        When both ``start_date`` and ``end_date`` are given the repository's
        date-range query is used; otherwise forecasts are fetched newest-first
        via ``get_multi``. NOTE(review): if only one of the two dates is given,
        the date filter is ignored.

        Args:
            tenant_id: Tenant whose forecasts are returned.
            inventory_product_id: Optional product filter.
            start_date: Inclusive range start (used only together with end_date).
            end_date: Range end (used only together with start_date).
            skip: Number of rows to skip (pagination offset).
            limit: Maximum number of rows returned.

        Returns:
            List of forecast dicts (ids stringified, dates ISO-formatted).

        Raises:
            Exception: Any repository error is logged and re-raised.
        """
        try:
            # Get session and initialize repositories
            async with self.database_manager.get_background_session() as session:
                repos = await self._init_repositories(session)

                # Build filters (used by get_multi and for logging)
                filters = {"tenant_id": tenant_id}
                if inventory_product_id:
                    filters["inventory_product_id"] = inventory_product_id

                if start_date and end_date:
                    # If date range specified, use specialized method
                    forecasts = await repos['forecast'].get_forecasts_by_date_range(
                        tenant_id=tenant_id,
                        start_date=start_date,
                        end_date=end_date,
                        inventory_product_id=inventory_product_id
                    )
                    # The date-range query takes no skip/limit, so paginate here;
                    # previously skip/limit were silently ignored on this path.
                    forecasts = list(forecasts)[skip:skip + limit]
                else:
                    # Use general get_multi with tenant filter
                    forecasts = await repos['forecast'].get_multi(
                        filters=filters,
                        skip=skip,
                        limit=limit,
                        order_by="forecast_date",
                        order_desc=True
                    )

                forecast_list = [self._forecast_to_detail_dict(forecast) for forecast in forecasts]

                logger.info("Retrieved tenant forecasts",
                            tenant_id=tenant_id,
                            count=len(forecast_list),
                            filters=filters)

                return forecast_list

        except Exception as e:
            logger.error("Failed to get tenant forecasts",
                         tenant_id=tenant_id,
                         error=str(e))
            raise

    async def get_forecast_by_id(self, forecast_id: str) -> Optional[Dict]:
        """Get forecast by ID. Placeholder: always returns ``None``."""
        try:
            # Implementation would use repository pattern
            return None
        except Exception as e:
            logger.error("Failed to get forecast by ID", error=str(e))
            raise

    async def delete_forecast(self, forecast_id: str) -> bool:
        """
        Delete forecast.

        NOTE(review): placeholder - nothing is deleted, yet ``True`` is
        returned; callers cannot rely on this result until implemented.
        """
        try:
            # Implementation would use repository pattern
            return True
        except Exception as e:
            logger.error("Failed to delete forecast", error=str(e))
            return False

    async def get_tenant_alerts(self, tenant_id: str, active_only: bool = True,
                                skip: int = 0, limit: int = 50) -> List[Dict]:
        """Get tenant alerts. Placeholder: always returns an empty list."""
        try:
            # Implementation would use repository pattern
            return []
        except Exception as e:
            logger.error("Failed to get tenant alerts", error=str(e))
            raise

    async def get_tenant_forecast_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """
        Get tenant forecast statistics.

        Placeholder: returns zeroed counters. On error an ``{"error": ...}``
        dict is returned instead of raising.
        """
        try:
            # Implementation would use repository pattern
            return {
                "total_forecasts": 0,
                "active_forecasts": 0,
                "recent_forecasts": 0,
                "accuracy_metrics": {},
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Failed to get forecast statistics", error=str(e))
            return {"error": str(e)}

    async def generate_batch_predictions(self, tenant_id: str, batch_request: Dict) -> Dict[str, Any]:
        """
        Generate batch predictions.

        Placeholder: ``batch_request`` is not used yet and an empty batch summary
        with a timestamp-based ``batch_id`` is returned.
        """
        try:
            # Implementation would use repository pattern
            return {
                "batch_id": f"pred_batch_{tenant_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                "tenant_id": tenant_id,
                "predictions": [],
                "total_predictions": 0,
                "successful_predictions": 0,
                "failed_predictions": 0,
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Batch predictions failed", error=str(e))
            raise

    async def get_cached_predictions(self, tenant_id: str, inventory_product_id: str = None,
                                     skip: int = 0, limit: int = 100) -> List[Dict]:
        """Get cached predictions. Placeholder: always returns an empty list."""
        try:
            # Implementation would use repository pattern
            return []
        except Exception as e:
            logger.error("Failed to get cached predictions", error=str(e))
            raise

    async def clear_prediction_cache(self, tenant_id: str, inventory_product_id: str = None) -> int:
        """Clear prediction cache. Placeholder: clears nothing and returns 0."""
        try:
            # Implementation would use repository pattern
            return 0
        except Exception as e:
            logger.error("Failed to clear prediction cache", error=str(e))
            return 0

    async def get_prediction_performance(self, tenant_id: str, model_id: str = None,
                                         start_date: date = None, end_date: date = None) -> Dict[str, Any]:
        """Get prediction performance metrics. Placeholder: returns empty metrics."""
        try:
            # Implementation would use repository pattern
            return {
                "accuracy_metrics": {},
                "performance_trends": [],
                "enhanced_features": True
            }
        except Exception as e:
            logger.error("Failed to get prediction performance", error=str(e))
            raise

    async def generate_forecast(
        self,
        tenant_id: str,
        request: "ForecastRequest"
    ) -> "ForecastResponse":
        """
        Generate forecast using repository pattern with caching and alerting.

        Flow: cache lookup -> model selection (tenant-wide fallback) -> feature
        preparation -> prediction -> business-rule adjustment -> persist
        forecast -> cache prediction for 24h -> best-effort alert creation.

        Args:
            tenant_id: Tenant the forecast belongs to.
            request: Forecast request (product, location, date, confidence level).

        Returns:
            ForecastResponse built from the cache hit or the newly saved forecast.

        Raises:
            ValueError: If no model is available for the product or tenant.
            Exception: Any other error is logged (with elapsed time) and re-raised.
        """
        start_time = datetime.utcnow()

        try:
            logger.info("Generating enhanced forecast",
                        tenant_id=tenant_id,
                        inventory_product_id=request.inventory_product_id,
                        date=request.forecast_date.isoformat())

            # Get session and initialize repositories
            async with self.database_manager.get_background_session() as session:
                repos = await self._init_repositories(session)

                # Step 1: Check cache first
                cached_prediction = await repos['cache'].get_cached_prediction(
                    tenant_id, request.inventory_product_id, request.location, request.forecast_date
                )

                if cached_prediction:
                    logger.debug("Using cached prediction",
                                 tenant_id=tenant_id,
                                 inventory_product_id=request.inventory_product_id)
                    return self._create_forecast_response_from_cache(cached_prediction)

                # Step 2: Get model with validation
                model_data = await self._get_latest_model_with_fallback(tenant_id, request.inventory_product_id)

                if not model_data:
                    raise ValueError(f"No valid model available for product: {request.inventory_product_id}")

                # Step 3: Prepare features with fallbacks
                features = await self._prepare_forecast_features_with_fallbacks(tenant_id, request)

                # Step 4: Generate prediction
                prediction_result = await self.prediction_service.predict(
                    model_id=model_data['model_id'],
                    model_path=model_data['model_path'],
                    features=features,
                    confidence_level=request.confidence_level
                )

                # Step 5: Apply business rules
                adjusted_prediction = self._apply_business_rules(
                    prediction_result, request, features
                )
                predicted_demand = adjusted_prediction['prediction']
                # _apply_business_rules always sets both bounds; the +/-20 %
                # fallbacks only guard against a future change to that helper.
                confidence_lower = adjusted_prediction.get('lower_bound', predicted_demand * 0.8)
                confidence_upper = adjusted_prediction.get('upper_bound', predicted_demand * 1.2)

                # Step 6: Save forecast using repository
                # Convert forecast_date to datetime if it's a string.
                # NOTE(review): the log call above already calls .isoformat() on
                # forecast_date, so a plain str would fail before reaching here.
                forecast_datetime = request.forecast_date
                if isinstance(forecast_datetime, str):
                    from dateutil.parser import parse
                    forecast_datetime = parse(forecast_datetime)

                forecast_data = {
                    "tenant_id": tenant_id,
                    "inventory_product_id": request.inventory_product_id,
                    "location": request.location,
                    "forecast_date": forecast_datetime,
                    "predicted_demand": predicted_demand,
                    "confidence_lower": confidence_lower,
                    "confidence_upper": confidence_upper,
                    "confidence_level": request.confidence_level,
                    "model_id": model_data['model_id'],
                    "model_version": model_data.get('version', '1.0'),
                    "algorithm": model_data.get('algorithm', 'prophet'),
                    "business_type": features.get('business_type', 'individual'),
                    "is_holiday": features.get('is_holiday', False),
                    "is_weekend": features.get('is_weekend', False),
                    "day_of_week": features.get('day_of_week', 0),
                    "weather_temperature": features.get('temperature'),
                    "weather_precipitation": features.get('precipitation'),
                    "weather_description": features.get('weather_description'),
                    "traffic_volume": features.get('traffic_volume'),
                    "processing_time_ms": int((datetime.utcnow() - start_time).total_seconds() * 1000),
                    "features_used": features
                }

                forecast = await repos['forecast'].create_forecast(forecast_data)

                # Step 7: Cache the prediction
                await repos['cache'].cache_prediction(
                    tenant_id=tenant_id,
                    inventory_product_id=request.inventory_product_id,
                    location=request.location,
                    forecast_date=forecast_datetime,
                    predicted_demand=predicted_demand,
                    confidence_lower=confidence_lower,
                    confidence_upper=confidence_upper,
                    model_id=model_data['model_id'],
                    expires_in_hours=24
                )

                # Step 8: Check for alerts (best-effort; never raises)
                await self._check_and_create_alerts(forecast, adjusted_prediction, repos)

                logger.info("Enhanced forecast generated successfully",
                            forecast_id=forecast.id,
                            tenant_id=tenant_id,
                            prediction=predicted_demand)

                return self._create_forecast_response_from_model(forecast)

        except Exception as e:
            processing_time = int((datetime.utcnow() - start_time).total_seconds() * 1000)
            logger.error("Error generating enhanced forecast",
                         error=str(e),
                         tenant_id=tenant_id,
                         inventory_product_id=request.inventory_product_id,
                         processing_time=processing_time)
            raise

    async def get_forecast_history(
        self,
        tenant_id: str,
        inventory_product_id: Optional[str] = None,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None
    ) -> List[Dict[str, Any]]:
        """
        Get forecast history using repository.

        With both dates, the repository date-range query is used; otherwise the
        last 30 days of records are fetched. ``inventory_product_id`` filters
        the result in both cases.

        Returns:
            List of forecast dicts, or ``[]`` on any error (the error is logged).
        """
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                if start_date and end_date:
                    forecasts = await repos['forecast'].get_forecasts_by_date_range(
                        tenant_id, start_date, end_date, inventory_product_id
                    )
                else:
                    # Get recent forecasts (last 30 days)
                    forecasts = await repos['forecast'].get_recent_records(
                        tenant_id, hours=24 * 30
                    )
                    # get_recent_records takes no product argument, so apply the
                    # product filter here; previously it was silently dropped.
                    if inventory_product_id:
                        wanted = str(inventory_product_id)
                        forecasts = [
                            forecast for forecast in forecasts
                            if str(forecast.inventory_product_id) == wanted
                        ]

                # Convert to dict format
                return [self._forecast_to_dict(forecast) for forecast in forecasts]

        except Exception as e:
            logger.error("Failed to get forecast history",
                         tenant_id=tenant_id,
                         error=str(e))
            return []

    async def get_forecast_analytics(self, tenant_id: str) -> Dict[str, Any]:
        """
        Get comprehensive forecast analytics using repositories.

        Aggregates forecast summary, alert, batch and cache statistics plus
        30-day performance trends. Returns ``{"error": ...}`` on failure.
        """
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                forecast_summary = await repos['forecast'].get_forecast_summary(tenant_id)
                alert_stats = await repos['alert'].get_alert_statistics(tenant_id)
                batch_stats = await repos['batch'].get_batch_statistics(tenant_id)
                cache_stats = await repos['cache'].get_cache_statistics(tenant_id)
                performance_trends = await repos['performance'].get_performance_trends(
                    tenant_id, days=30
                )

                return {
                    "tenant_id": tenant_id,
                    "forecast_analytics": forecast_summary,
                    "alert_analytics": alert_stats,
                    "batch_analytics": batch_stats,
                    "cache_performance": cache_stats,
                    "performance_trends": performance_trends,
                    "generated_at": datetime.utcnow().isoformat()
                }

        except Exception as e:
            logger.error("Failed to get forecast analytics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {"error": f"Failed to get analytics: {str(e)}"}

    async def create_batch_prediction(
        self,
        tenant_id: str,
        batch_name: str,
        inventory_product_ids: List[str],
        forecast_days: int = 7
    ) -> Dict[str, Any]:
        """
        Create a pending batch prediction job record.

        Only the batch row is created here; the product ids are used for the
        ``total_products`` count.

        Raises:
            DatabaseError: Wrapping (and chained to) any underlying failure.
        """
        try:
            async with self.database_manager.get_session() as session:
                repos = await self._init_repositories(session)

                # Create batch record
                batch_data = {
                    "tenant_id": tenant_id,
                    "batch_name": batch_name,
                    "total_products": len(inventory_product_ids),
                    "forecast_days": forecast_days,
                    "status": "pending"
                }

                batch = await repos['batch'].create_batch(batch_data)

                logger.info("Batch prediction created",
                            batch_id=batch.id,
                            tenant_id=tenant_id,
                            total_products=len(inventory_product_ids))

                return {
                    "batch_id": str(batch.id),
                    "status": batch.status,
                    "total_products": len(inventory_product_ids),
                    "created_at": batch.requested_at.isoformat()
                }

        except Exception as e:
            logger.error("Failed to create batch prediction",
                         tenant_id=tenant_id,
                         error=str(e))
            # Chain the original exception so the root cause stays in tracebacks.
            raise DatabaseError(f"Failed to create batch: {str(e)}") from e

    async def _check_and_create_alerts(self, forecast, prediction: Dict[str, Any], repos: Dict):
        """
        Check forecast results and create alerts if necessary (best-effort).

        Rules: prediction > 100 -> high_demand (severity "high" above 200,
        else "medium"); otherwise prediction < 10 -> low_demand; additionally,
        prediction < 5 with interval width <= 10 -> stockout_risk.

        Args:
            forecast: Saved forecast model (reads id, tenant_id, inventory_product_id).
            prediction: Output of _apply_business_rules (needs 'prediction',
                'lower_bound', 'upper_bound').
            repos: Repository dict from _init_repositories (uses 'alert').

        Errors are logged and swallowed: alerts are not critical for forecast
        generation.
        """
        try:
            value = prediction['prediction']
            product = str(forecast.inventory_product_id)

            def _alert(alert_type: str, severity: str, message: str) -> Dict[str, Any]:
                # Builds one alert payload; ids are stringified (UUID -> str).
                return {
                    "tenant_id": str(forecast.tenant_id),
                    "forecast_id": str(forecast.id),
                    "alert_type": alert_type,
                    "severity": severity,
                    "message": message
                }

            alerts_to_create = []

            if value > 100:  # Threshold for high demand
                alerts_to_create.append(_alert(
                    "high_demand",
                    "high" if value > 200 else "medium",
                    f"High demand predicted for inventory product {product}: {value:.1f} units"
                ))
            elif value < 10:  # Threshold for low demand
                alerts_to_create.append(_alert(
                    "low_demand",
                    "low",
                    f"Low demand predicted for inventory product {product}: {value:.1f} units"
                ))

            # Stockout risk: very low prediction with a narrow confidence interval.
            # _apply_business_rules enforces a minimum interval width of 10, so
            # the former strict "< 10" test could never be true (dead alert);
            # "<= 10" catches intervals sitting at that minimum width.
            confidence_interval = prediction['upper_bound'] - prediction['lower_bound']
            if value < 5 and confidence_interval <= 10:
                alerts_to_create.append(_alert(
                    "stockout_risk",
                    "critical",
                    f"Stockout risk for inventory product {product}: predicted {value:.1f} units with high confidence"
                ))

            for alert_data in alerts_to_create:
                await repos['alert'].create_alert(alert_data)

        except Exception as e:
            logger.error("Failed to create alerts",
                         forecast_id=forecast.id,
                         error=str(e))
            # Don't raise - alerts are not critical for forecast generation

    def _create_forecast_response_from_cache(self, cache_entry) -> "ForecastResponse":
        """
        Create forecast response from cached entry.

        The cache row carries only the prediction and bounds, so metadata
        fields are filled with fixed defaults ("cached" version/algorithm,
        confidence level 0.8, no holiday); weekend/weekday are recomputed
        from the cached forecast_date.
        """
        return ForecastResponse(
            id=str(cache_entry.id),
            tenant_id=str(cache_entry.tenant_id),
            inventory_product_id=str(cache_entry.inventory_product_id),  # Convert UUID to string
            location=cache_entry.location,
            forecast_date=cache_entry.forecast_date,
            predicted_demand=cache_entry.predicted_demand,
            confidence_lower=cache_entry.confidence_lower,
            confidence_upper=cache_entry.confidence_upper,
            confidence_level=0.8,  # Default
            model_id=str(cache_entry.model_id),
            model_version="cached",
            algorithm="cached",
            business_type="individual",
            is_holiday=False,
            is_weekend=cache_entry.forecast_date.weekday() >= 5,
            day_of_week=cache_entry.forecast_date.weekday(),
            created_at=cache_entry.created_at,
            processing_time_ms=0,  # From cache
            features_used={}
        )

    def _create_forecast_response_from_model(self, forecast) -> "ForecastResponse":
        """Create forecast response from a persisted forecast model."""
        return ForecastResponse(
            id=str(forecast.id),
            tenant_id=str(forecast.tenant_id),
            inventory_product_id=str(forecast.inventory_product_id),  # Convert UUID to string
            location=forecast.location,
            forecast_date=forecast.forecast_date,
            predicted_demand=forecast.predicted_demand,
            confidence_lower=forecast.confidence_lower,
            confidence_upper=forecast.confidence_upper,
            confidence_level=forecast.confidence_level,
            model_id=str(forecast.model_id),
            model_version=forecast.model_version,
            algorithm=forecast.algorithm,
            business_type=forecast.business_type,
            is_holiday=forecast.is_holiday,
            is_weekend=forecast.is_weekend,
            day_of_week=forecast.day_of_week,
            weather_temperature=forecast.weather_temperature,
            weather_precipitation=forecast.weather_precipitation,
            weather_description=forecast.weather_description,
            traffic_volume=forecast.traffic_volume,
            created_at=forecast.created_at,
            processing_time_ms=forecast.processing_time_ms,
            features_used=forecast.features_used
        )

    def _forecast_to_dict(self, forecast) -> Dict[str, Any]:
        """Convert forecast model to a compact dictionary (used by history)."""
        return {
            "id": str(forecast.id),
            "tenant_id": str(forecast.tenant_id),
            "inventory_product_id": str(forecast.inventory_product_id),  # Convert UUID to string
            "location": forecast.location,
            "forecast_date": forecast.forecast_date.isoformat(),
            "predicted_demand": forecast.predicted_demand,
            "confidence_lower": forecast.confidence_lower,
            "confidence_upper": forecast.confidence_upper,
            "confidence_level": forecast.confidence_level,
            "model_id": str(forecast.model_id),
            "algorithm": forecast.algorithm,
            "created_at": forecast.created_at.isoformat() if forecast.created_at else None
        }

    def _forecast_to_detail_dict(self, forecast) -> Dict[str, Any]:
        """Convert forecast model to the detailed dict returned by get_tenant_forecasts."""
        return {
            "id": str(forecast.id),
            "tenant_id": str(forecast.tenant_id),
            # Stringified like _forecast_to_dict so UUID ids stay serializable.
            "inventory_product_id": str(forecast.inventory_product_id),
            "location": forecast.location,
            "forecast_date": forecast.forecast_date.isoformat(),
            "predicted_demand": float(forecast.predicted_demand),
            "confidence_lower": float(forecast.confidence_lower),
            "confidence_upper": float(forecast.confidence_upper),
            "confidence_level": float(forecast.confidence_level),
            "model_id": forecast.model_id,
            "model_version": forecast.model_version,
            "algorithm": forecast.algorithm,
            "business_type": forecast.business_type,
            "is_holiday": forecast.is_holiday,
            "is_weekend": forecast.is_weekend,
            "processing_time_ms": forecast.processing_time_ms,
            "created_at": forecast.created_at.isoformat() if forecast.created_at else None
        }

    # Additional helper methods from original service
    async def _get_latest_model_with_fallback(self, tenant_id: str, inventory_product_id: str) -> Optional[Dict[str, Any]]:
        """
        Get the latest trained model with fallback strategies.

        Tries the best model for the specific product first, then any model
        for the tenant.

        Returns:
            Model info dict from ModelClient, or ``None`` if none is found or
            the lookup fails (the error is logged).
        """
        try:
            model_data = await self.model_client.get_best_model_for_forecasting(
                tenant_id=tenant_id,
                inventory_product_id=inventory_product_id
            )

            if model_data:
                logger.info("Found specific model for product",
                            inventory_product_id=inventory_product_id,
                            model_id=model_data.get('model_id'))
                return model_data

            # Fallback: Try to get any model for this tenant
            fallback_model = await self.model_client.get_any_model_for_tenant(tenant_id)

            if fallback_model:
                logger.info("Using fallback model",
                            model_id=fallback_model.get('model_id'))
                return fallback_model

            logger.error("No models available for tenant", tenant_id=tenant_id)
            return None

        except Exception as e:
            logger.error("Error getting model",
                         tenant_id=tenant_id,
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            return None

    async def _prepare_forecast_features_with_fallbacks(
        self,
        tenant_id: str,
        request: "ForecastRequest"
    ) -> Dict[str, Any]:
        """
        Prepare features with comprehensive fallbacks.

        Calendar features are derived from ``request.forecast_date``; weather
        and traffic features are fixed placeholder values (traffic scaled by
        0.7 on weekends). ``tenant_id`` is currently unused.
        """
        forecast_date = request.forecast_date
        weekday = forecast_date.weekday()

        features = {
            "date": forecast_date.isoformat(),
            "day_of_week": weekday,
            "is_weekend": weekday >= 5,
            "day_of_month": forecast_date.day,
            "month": forecast_date.month,
            "quarter": (forecast_date.month - 1) // 3 + 1,
            "week_of_year": forecast_date.isocalendar().week,
            "season": self._get_season(forecast_date.month),
            "is_holiday": self._is_spanish_holiday(forecast_date),
        }

        # Add weather features (simplified default values)
        features.update({
            "temperature": 20.0,
            "precipitation": 0.0,
            "humidity": 65.0,
            "wind_speed": 5.0,
            "pressure": 1013.0,
        })

        # Add traffic features (simplified)
        weekend_factor = 0.7 if features["is_weekend"] else 1.0
        features.update({
            "traffic_volume": int(100 * weekend_factor),
            "pedestrian_count": int(50 * weekend_factor),
        })

        return features

    def _get_season(self, month: int) -> int:
        """Map a month (1-12) to a season: 1 winter, 2 spring, 3 summer, 4 autumn."""
        if month in [12, 1, 2]:
            return 1  # Winter
        elif month in [3, 4, 5]:
            return 2  # Spring
        elif month in [6, 7, 8]:
            return 3  # Summer
        else:
            return 4  # Autumn

    def _is_spanish_holiday(self, date: datetime) -> bool:
        """
        Check if a date is a major Spanish holiday.

        Accepts any date/datetime (only ``month`` and ``day`` are read) and
        checks it against a fixed list of national holidays.
        """
        return (date.month, date.day) in self._SPANISH_HOLIDAYS

    def _apply_business_rules(
        self,
        prediction: Dict[str, float],
        request: "ForecastRequest",
        features: Dict[str, Any]
    ) -> Dict[str, float]:
        """
        Apply Spanish bakery business rules to predictions.

        Multiplicative adjustments: weekend x0.8, holiday x0.5, precipitation
        above 2.0 x0.7. The confidence interval width is scaled by the same
        factor and re-centred on the adjusted prediction. The lower bound is
        floored at max(20 % of prediction, 5) but never exceeds the prediction,
        and the interval is at least 10 wide.

        Args:
            prediction: Raw model output; must contain "prediction", may contain
                "lower_bound", "upper_bound" (default +/-20 %) and "confidence_level".
            request: The originating request (currently unused).
            features: Feature dict; reads is_weekend, is_holiday, precipitation.

        Returns:
            Dict with prediction, lower_bound, upper_bound, confidence_interval,
            confidence_level and adjustment_factor.
        """
        base_prediction = prediction["prediction"]

        # Ensure confidence bounds exist with fallbacks
        lower_bound = prediction.get("lower_bound", base_prediction * 0.8)
        upper_bound = prediction.get("upper_bound", base_prediction * 1.2)

        # Apply adjustment factors
        adjustment_factor = 1.0
        if features.get("is_weekend", False):
            adjustment_factor *= 0.8
        if features.get("is_holiday", False):
            adjustment_factor *= 0.5

        # Weather adjustments
        precipitation = features.get("precipitation", 0.0)
        if precipitation > 2.0:
            adjustment_factor *= 0.7

        # Apply adjustments to prediction (demand cannot be negative)
        adjusted_prediction = max(0, base_prediction * adjustment_factor)

        # Scale the interval width by the same factor, centred on the prediction
        original_interval = upper_bound - lower_bound
        adjusted_interval = original_interval * adjustment_factor
        half_width = adjusted_interval / 2

        # Minimum reasonable lower bound: 20% of prediction or 5, whichever is larger
        min_lower_bound = max(adjusted_prediction * 0.2, 5.0)
        adjusted_lower = max(min_lower_bound, adjusted_prediction - half_width)
        # For predictions below 5 that floor would push the lower bound *above*
        # the point estimate; clamp so lower_bound <= prediction always holds.
        adjusted_lower = min(adjusted_lower, adjusted_prediction)
        adjusted_upper = max(adjusted_lower + 10, adjusted_prediction + half_width)

        return {
            "prediction": adjusted_prediction,
            "lower_bound": adjusted_lower,
            "upper_bound": adjusted_upper,
            "confidence_interval": adjusted_upper - adjusted_lower,
            "confidence_level": prediction.get("confidence_level", 0.8),
            "adjustment_factor": adjustment_factor
        }
# Legacy compatibility alias: the old name ``ForecastingService`` is bound to
# the repository-based EnhancedForecastingService class (same object, not a subclass).
ForecastingService = EnhancedForecastingService