Files
bakery-ia/services/forecasting/app/services/enterprise_forecasting_service.py
2025-12-05 20:07:01 +01:00

260 lines
11 KiB
Python

"""
Enterprise forecasting service for aggregated demand across parent-child tenants
"""
import logging
from typing import Dict, Any, List, Optional
from datetime import date, datetime
import json
import redis.asyncio as redis
from shared.clients.forecast_client import ForecastServiceClient
from shared.clients.tenant_client import TenantServiceClient
logger = logging.getLogger(__name__)
class EnterpriseForecastingService:
    """
    Service for aggregating forecasts across parent and child tenants.

    Aggregation results are cached in Redis for ``cache_ttl_seconds`` to avoid
    recomputing the expensive cross-tenant fan-out on every request. Cache
    failures are treated as best-effort: a Redis outage degrades to a fresh
    computation instead of an error.
    """

    def __init__(
        self,
        forecast_client: ForecastServiceClient,
        tenant_client: TenantServiceClient,
        redis_client: redis.Redis
    ):
        # Clients are injected so the service stays testable and transport-agnostic.
        self.forecast_client = forecast_client
        self.tenant_client = tenant_client
        self.redis_client = redis_client
        self.cache_ttl_seconds = 3600  # 1 hour TTL

    async def get_aggregated_forecast(
        self,
        parent_tenant_id: str,
        start_date: date,
        end_date: date,
        product_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get aggregated forecast across parent and all child tenants.

        Args:
            parent_tenant_id: Parent tenant ID
            start_date: Start date for forecast aggregation
            end_date: End date for forecast aggregation
            product_id: Optional product ID to filter by

        Returns:
            Dict with aggregated forecast data by date and product
        """
        cache_key = f"agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id or 'all'}"

        # Best-effort cache read: never let a cache failure break forecasting.
        try:
            cached_result = await self.redis_client.get(cache_key)
            if cached_result:
                logger.info(f"Cache hit for aggregated forecast: {cache_key}")
                return json.loads(cached_result)
        except Exception as e:
            logger.warning(f"Cache read failed: {e}")

        logger.info(f"Computing aggregated forecast for parent {parent_tenant_id} from {start_date} to {end_date}")

        # Get child tenant IDs and include the parent for complete aggregation.
        child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
        child_tenant_ids = [child['id'] for child in child_tenants]
        all_tenant_ids = [parent_tenant_id] + child_tenant_ids

        all_forecasts: Dict[str, Dict[str, Any]] = {}
        # NOTE(review): this top-level map is initialised per date but never
        # populated — per-tenant contributions are stored inside each product
        # entry instead. Kept as-is so the response shape stays backward
        # compatible for existing consumers.
        tenant_contributions: Dict[str, Dict[str, Any]] = {}

        for tenant_id in all_tenant_ids:
            try:
                tenant_forecasts = await self.forecast_client.get_forecasts(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    product_id=product_id
                )
                for forecast_date_str, products in tenant_forecasts.items():
                    if forecast_date_str not in all_forecasts:
                        all_forecasts[forecast_date_str] = {}
                        tenant_contributions[forecast_date_str] = {}
                    for product_id_key, forecast_data in products.items():
                        self._merge_forecast_entry(
                            all_forecasts[forecast_date_str],
                            product_id_key,
                            tenant_id,
                            forecast_data
                        )
            except Exception as e:
                # Continue with other tenants even if one fails.
                logger.error(f"Failed to fetch forecasts for tenant {tenant_id}: {e}")

        result = {
            "parent_tenant_id": parent_tenant_id,
            "aggregated_forecasts": all_forecasts,
            "tenant_contributions": tenant_contributions,
            "child_tenant_count": len(child_tenant_ids),
            "forecast_dates": list(all_forecasts.keys()),
            # NOTE: utcnow() is deprecated in 3.12+, kept so the timestamp
            # format (naive ISO string) stays unchanged for consumers.
            "computed_at": datetime.utcnow().isoformat()
        }

        # Best-effort cache write.
        try:
            await self.redis_client.setex(
                cache_key,
                self.cache_ttl_seconds,
                json.dumps(result, default=str)  # Handle date serialization
            )
            logger.info(f"Forecast cached for {cache_key}")
        except Exception as e:
            logger.warning(f"Cache write failed: {e}")
        return result

    @staticmethod
    def _merge_forecast_entry(
        day_bucket: Dict[str, Any],
        product_id_key: str,
        tenant_id: str,
        forecast_data: Dict[str, Any]
    ) -> None:
        """
        Fold one tenant's forecast for a single product into the per-date
        aggregate bucket, mutating ``day_bucket`` in place.
        """
        if product_id_key not in day_bucket:
            day_bucket[product_id_key] = {
                'predicted_demand': 0,
                'confidence_lower': 0,
                'confidence_upper': 0,
                'tenant_contributions': []
            }
        entry = day_bucket[product_id_key]
        entry['predicted_demand'] += forecast_data.get('predicted_demand', 0)
        # Confidence bounds are combined by simple summation; a real
        # implementation would require proper statistical combination.
        entry['confidence_lower'] += forecast_data.get('confidence_lower', 0)
        entry['confidence_upper'] += forecast_data.get('confidence_upper', 0)
        # Track the per-tenant contribution alongside the aggregate.
        entry['tenant_contributions'].append({
            'tenant_id': tenant_id,
            'demand': forecast_data.get('predicted_demand', 0),
            'confidence_lower': forecast_data.get('confidence_lower', 0),
            'confidence_upper': forecast_data.get('confidence_upper', 0)
        })

    @staticmethod
    def _sum_forecasted(aggregated_forecasts: Dict[str, Any]) -> float:
        """
        Sum predicted demand across all dates and products.

        Each date value is normally a dict mapping product_id -> forecast
        entry; a bare numeric day value is tolerated and added directly.

        BUG FIX: the original inline expression read
        ``day.get('predicted_demand', 0)`` while iterating products, which
        always yielded 0 (``day`` maps product IDs, not forecast fields), so
        the network forecast total — and hence accuracy — was always 0.
        """
        total = 0
        for day in aggregated_forecasts.values():
            if isinstance(day, dict):
                total += sum(
                    product.get('predicted_demand', 0)
                    for product in day.values()
                    if isinstance(product, dict)
                )
            else:
                total += day
        return total

    async def get_network_performance_metrics(
        self,
        parent_tenant_id: str,
        start_date: date,
        end_date: date
    ) -> Dict[str, Any]:
        """
        Get aggregated performance metrics across the tenant network.

        Args:
            parent_tenant_id: Parent tenant ID
            start_date: Start date for metrics
            end_date: End date for metrics

        Returns:
            Dict with aggregated performance metrics
        """
        child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
        child_tenant_ids = [child['id'] for child in child_tenants]
        # Include parent tenant in the list for complete aggregation.
        all_tenant_ids = [parent_tenant_id] + child_tenant_ids

        total_sales = 0
        total_forecasted = 0
        total_accuracy = 0
        tenant_count = 0
        performance_data: Dict[str, Any] = {}

        for tenant_id in all_tenant_ids:
            try:
                # Fetch sales and forecast data for the period.
                # NOTE(review): calling get_aggregated_forecast per tenant
                # treats each child as a parent; if children themselves have
                # children this may double-count — confirm tenant hierarchy
                # is only one level deep.
                sales_data = await self._fetch_sales_data(tenant_id, start_date, end_date)
                forecast_data = await self.get_aggregated_forecast(tenant_id, start_date, end_date)
                tenant_performance = {
                    'tenant_id': tenant_id,
                    'sales': sales_data.get('total_sales', 0),
                    'forecasted': self._sum_forecasted(
                        forecast_data.get('aggregated_forecasts', {})
                    ),
                }
                # Calculate accuracy only when both sales and forecast exist.
                if tenant_performance['sales'] > 0 and tenant_performance['forecasted'] > 0:
                    accuracy = 1 - abs(tenant_performance['forecasted'] - tenant_performance['sales']) / tenant_performance['sales']
                    tenant_performance['accuracy'] = max(0, min(1, accuracy))  # Clamp between 0 and 1
                else:
                    tenant_performance['accuracy'] = 0
                performance_data[tenant_id] = tenant_performance
                total_sales += tenant_performance['sales']
                total_forecasted += tenant_performance['forecasted']
                total_accuracy += tenant_performance['accuracy']
                tenant_count += 1
            except Exception as e:
                # Skip failing tenants; metrics cover whatever succeeded.
                logger.error(f"Failed to fetch performance data for tenant {tenant_id}: {e}")

        network_performance = {
            "parent_tenant_id": parent_tenant_id,
            "total_sales": total_sales,
            "total_forecasted": total_forecasted,
            "average_accuracy": total_accuracy / tenant_count if tenant_count > 0 else 0,
            "tenant_count": tenant_count,
            "child_tenant_count": len(child_tenant_ids),
            "tenant_performances": performance_data,
            "computed_at": datetime.utcnow().isoformat()
        }
        return network_performance

    async def _fetch_sales_data(self, tenant_id: str, start_date: date, end_date: date) -> Dict[str, Any]:
        """
        Helper method to fetch sales data from the sales service.

        Returns a summary dict with ``total_sales`` (sum of per-record
        ``quantity_sold``); on any error, returns a zeroed summary with an
        ``'error'`` key instead of raising, so callers stay best-effort.
        """
        try:
            # Imported lazily to avoid a hard dependency at module load time.
            from shared.clients.sales_client import SalesServiceClient
            from shared.config.base import get_settings

            config = get_settings()
            sales_client = SalesServiceClient(config, calling_service_name="forecasting")

            sales_data = await sales_client.get_sales_data(
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                aggregation="daily"
            )

            # Sum up quantity_sold across the retrieved records.
            total_sales = 0
            if sales_data:
                for sale in sales_data:
                    total_sales += sale.get('quantity_sold', 0)
            return {
                'total_sales': total_sales,
                'date_range': f"{start_date} to {end_date}",
                'tenant_id': tenant_id,
                'record_count': len(sales_data) if sales_data else 0
            }
        except Exception as e:
            logger.error(f"Failed to fetch sales data for tenant {tenant_id}: {e}")
            # Return empty result on error.
            return {
                'total_sales': 0,
                'date_range': f"{start_date} to {end_date}",
                'tenant_id': tenant_id,
                'error': str(e)
            }