New enterprise feature

This commit is contained in:
Urtzi Alfaro
2025-11-30 09:12:40 +01:00
parent f9d0eec6ec
commit 972db02f6d
176 changed files with 19741 additions and 1361 deletions

View File

@@ -11,6 +11,7 @@ from .historical_validation import router as historical_validation_router
from .webhooks import router as webhooks_router
from .performance_monitoring import router as performance_monitoring_router
from .retraining import router as retraining_router
from .enterprise_forecasting import router as enterprise_forecasting_router
__all__ = [
@@ -22,4 +23,5 @@ __all__ = [
"webhooks_router",
"performance_monitoring_router",
"retraining_router",
"enterprise_forecasting_router",
]

View File

@@ -0,0 +1,108 @@
"""
Enterprise forecasting API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from datetime import date
import structlog
from app.services.enterprise_forecasting_service import EnterpriseForecastingService
from shared.auth.tenant_access import verify_tenant_permission_dep
from shared.clients import get_forecast_client, get_tenant_client
import shared.redis_utils
from app.core.config import settings
# Module-level structured logger for this API module.
logger = structlog.get_logger()
router = APIRouter()
# Global Redis client
# Lazily created by get_forecasting_redis_client() on first use and shared
# across requests; stays None when Redis is unreachable.
_redis_client = None
async def get_forecasting_redis_client():
    """Return the shared Redis client, creating it lazily on first use.

    Returns ``None`` when Redis cannot be initialized so dependent
    services can degrade gracefully instead of failing the request.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    try:
        client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("Redis client initialized for enterprise forecasting")
    except Exception as e:
        logger.warning("Failed to initialize Redis client, enterprise forecasting will work with limited functionality", error=str(e))
        return None
    _redis_client = client
    return _redis_client
async def get_enterprise_forecasting_service(
    redis_client = Depends(get_forecasting_redis_client)
) -> EnterpriseForecastingService:
    """FastAPI dependency that builds an EnterpriseForecastingService.

    The Redis client is injected via get_forecasting_redis_client and may
    be None when Redis is unavailable.
    """
    service_name = "forecasting-service"
    return EnterpriseForecastingService(
        forecast_client=get_forecast_client(settings, service_name),
        tenant_client=get_tenant_client(settings, service_name),
        redis_client=redis_client,
    )
@router.get("/tenants/{tenant_id}/forecasting/enterprise/aggregated")
async def get_aggregated_forecast(
    tenant_id: str,
    start_date: date = Query(..., description="Start date for forecast aggregation"),
    end_date: date = Query(..., description="End date for forecast aggregation"),
    product_id: Optional[str] = Query(None, description="Optional product ID to filter by"),
    enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get aggregated forecasts across parent and child tenants.

    Only tenants whose ``tenant_type`` is ``'parent'`` may call this
    endpoint.

    Raises:
        HTTPException 403: the tenant is not a parent tenant.
        HTTPException 500: aggregation or a downstream service call failed.
    """
    try:
        # Check if this tenant is a parent tenant
        tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access aggregated enterprise forecasts"
            )
        result = await enterprise_forecasting_service.get_aggregated_forecast(
            parent_tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            product_id=product_id
        )
        return result
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 403 above) unchanged;
        # the generic handler below would otherwise mask them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get aggregated forecast: {str(e)}")
@router.get("/tenants/{tenant_id}/forecasting/enterprise/network-performance")
async def get_network_performance_metrics(
    tenant_id: str,
    start_date: date = Query(..., description="Start date for metrics"),
    end_date: date = Query(..., description="End date for metrics"),
    enterprise_forecasting_service: EnterpriseForecastingService = Depends(get_enterprise_forecasting_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get aggregated performance metrics across tenant network.

    Only tenants whose ``tenant_type`` is ``'parent'`` may call this
    endpoint.

    Raises:
        HTTPException 403: the tenant is not a parent tenant.
        HTTPException 500: metric computation or a downstream call failed.
    """
    try:
        # Check if this tenant is a parent tenant
        tenant_info = await enterprise_forecasting_service.tenant_client.get_tenant(tenant_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access network performance metrics"
            )
        result = await enterprise_forecasting_service.get_network_performance_metrics(
            parent_tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date
        )
        return result
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 403 above) instead of
        # converting them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get network performance: {str(e)}")

View File

@@ -23,17 +23,14 @@ from app.models.forecasts import Forecast, PredictionBatch
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Internal API key for service-to-service auth
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "dev-internal-key-change-in-production")
# Base demo tenant IDs
DEMO_TENANT_SAN_PABLO = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
DEMO_TENANT_LA_ESPIGA = "b2c3d4e5-f6a7-48b9-c0d1-e2f3a4b5c6d7"
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
    """Verify internal API key for service-to-service communication.

    Args:
        x_internal_api_key: value of the ``X-Internal-API-Key`` request
            header (FastAPI converts the parameter name automatically).

    Returns:
        True when the key matches settings.INTERNAL_API_KEY.

    Raises:
        HTTPException 403: the header is missing or does not match.
    """
    # Imported locally so the settings module is only loaded when the
    # internal endpoints are actually exercised.
    from app.core.config import settings
    if x_internal_api_key != settings.INTERNAL_API_KEY:
        logger.warning("Unauthorized internal API access attempted")
        raise HTTPException(status_code=403, detail="Invalid internal API key")
    return True

View File

@@ -0,0 +1,178 @@
"""
Forecast event consumer for the forecasting service
Handles events that should trigger cache invalidation for aggregated forecasts
"""
import logging
from typing import Dict, Any, Optional
import json
import redis.asyncio as redis
logger = logging.getLogger(__name__)
class ForecastEventConsumer:
    """
    Consumer for forecast events that may trigger cache invalidation.

    When a child tenant's forecasts change, the parent tenant's cached
    aggregated forecasts become stale; this consumer removes the matching
    Redis entries so they are recomputed on the next read.
    """

    def __init__(self, redis_client: redis.Redis):
        # Async Redis client used to scan and delete aggregated-forecast
        # cache keys.
        self.redis_client = redis_client

    async def handle_forecast_updated(self, event_data: Dict[str, Any]):
        """
        Handle forecast updated event.

        Invalidate parent tenant's aggregated forecast cache if this tenant
        is a child.

        Args:
            event_data: event payload; requires 'tenant_id', may carry
                'forecast_date' and 'product_id'.

        Raises:
            Exception: re-raised after logging so the caller/event bus can
                decide on retry semantics.
        """
        try:
            logger.info(f"Handling forecast updated event: {event_data}")
            tenant_id = event_data.get('tenant_id')
            forecast_date = event_data.get('forecast_date')
            product_id = event_data.get('product_id')
            if not tenant_id:
                # Without a tenant there is nothing to invalidate; drop event.
                logger.error("Missing tenant_id in forecast event")
                return
            # Check if this tenant is a child tenant (has parent)
            # In a real implementation, this would call the tenant service to check hierarchy
            parent_tenant_id = await self._get_parent_tenant_id(tenant_id)
            if parent_tenant_id:
                # Invalidate parent's aggregated forecast cache
                await self._invalidate_parent_aggregated_cache(
                    parent_tenant_id=parent_tenant_id,
                    child_tenant_id=tenant_id,
                    forecast_date=forecast_date,
                    product_id=product_id
                )
            logger.info(f"Forecast updated event processed for tenant {tenant_id}")
        except Exception as e:
            logger.error(f"Error handling forecast updated event: {e}", exc_info=True)
            raise

    async def handle_forecast_created(self, event_data: Dict[str, Any]):
        """
        Handle forecast created event.

        Similar to update, may affect parent tenant's aggregated forecasts.
        """
        # Creation has the same cache-invalidation consequences as an update.
        await self.handle_forecast_updated(event_data)

    async def handle_forecast_deleted(self, event_data: Dict[str, Any]):
        """
        Handle forecast deleted event.

        Similar to update, may affect parent tenant's aggregated forecasts.

        Raises:
            Exception: re-raised after logging (see handle_forecast_updated).
        """
        try:
            logger.info(f"Handling forecast deleted event: {event_data}")
            tenant_id = event_data.get('tenant_id')
            forecast_date = event_data.get('forecast_date')
            product_id = event_data.get('product_id')
            if not tenant_id:
                logger.error("Missing tenant_id in forecast delete event")
                return
            # Check if this tenant is a child tenant
            parent_tenant_id = await self._get_parent_tenant_id(tenant_id)
            if parent_tenant_id:
                # Invalidate parent's aggregated forecast cache
                await self._invalidate_parent_aggregated_cache(
                    parent_tenant_id=parent_tenant_id,
                    child_tenant_id=tenant_id,
                    forecast_date=forecast_date,
                    product_id=product_id
                )
            logger.info(f"Forecast deleted event processed for tenant {tenant_id}")
        except Exception as e:
            logger.error(f"Error handling forecast deleted event: {e}", exc_info=True)
            raise

    async def _get_parent_tenant_id(self, tenant_id: str) -> Optional[str]:
        """
        Get parent tenant ID for a child tenant.

        Placeholder: always returns None. In a real implementation this
        would call the tenant service (e.g. TenantServiceClient) to resolve
        the hierarchy, which means no invalidation currently ever fires.
        """
        try:
            # In real implementation: return await self.tenant_client.get_parent_tenant_id(tenant_id)
            return None  # Placeholder - real implementation needed
        except Exception as e:
            logger.error(f"Error getting parent tenant ID for {tenant_id}: {e}")
            return None

    async def _invalidate_parent_aggregated_cache(
        self,
        parent_tenant_id: str,
        child_tenant_id: str,
        forecast_date: Optional[str] = None,
        product_id: Optional[str] = None
    ):
        """
        Invalidate parent tenant's aggregated forecast cache.

        Deletes every cached aggregate for the parent regardless of
        forecast_date/product_id; those parameters are accepted for a
        future finer-grained invalidation but are currently unused.

        Raises:
            Exception: re-raised after logging when Redis operations fail.
        """
        try:
            # Pattern to match all aggregated forecast cache keys for this parent
            # Format: agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id}
            pattern = f"agg_forecast:{parent_tenant_id}:*:*:*"
            # Find all matching keys and delete them
            keys_to_delete = []
            async for key in self.redis_client.scan_iter(match=pattern):
                # scan_iter may yield bytes depending on client configuration.
                if isinstance(key, bytes):
                    key = key.decode('utf-8')
                keys_to_delete.append(key)
            if keys_to_delete:
                await self.redis_client.delete(*keys_to_delete)
                logger.info(f"Invalidated {len(keys_to_delete)} aggregated forecast cache entries for parent tenant {parent_tenant_id}")
            else:
                logger.info(f"No aggregated forecast cache entries found to invalidate for parent tenant {parent_tenant_id}")
        except Exception as e:
            logger.error(f"Error invalidating parent aggregated cache: {e}", exc_info=True)
            raise

    async def handle_tenant_hierarchy_changed(self, event_data: Dict[str, Any]):
        """
        Handle tenant hierarchy change event.

        This could be when a tenant becomes a child of another, or when the
        hierarchy changes ('action' is expected to be 'added', 'removed' or
        'changed').

        Raises:
            Exception: re-raised after logging when invalidation fails.
        """
        try:
            logger.info(f"Handling tenant hierarchy change event: {event_data}")
            tenant_id = event_data.get('tenant_id')
            parent_tenant_id = event_data.get('parent_tenant_id')
            action = event_data.get('action')  # 'added', 'removed', 'changed'
            # Invalidate any cached aggregated forecasts that might be affected
            if parent_tenant_id:
                # If this child tenant changed, invalidate parent's cache
                await self._invalidate_parent_aggregated_cache(
                    parent_tenant_id=parent_tenant_id,
                    child_tenant_id=tenant_id
                )
            # If this was a former parent tenant that's no longer a parent,
            # its aggregated cache might need to be invalidated differently
            if action == 'removed' and event_data.get('was_parent'):
                # Invalidate its own aggregated cache since it's no longer a parent
                # This would be handled by tenant service events
                pass
        except Exception as e:
            logger.error(f"Error handling tenant hierarchy change event: {e}", exc_info=True)
            raise

View File

@@ -15,7 +15,7 @@ from app.services.forecasting_alert_service import ForecastingAlertService
from shared.service_base import StandardFastAPIService
# Import API routers
from app.api import forecasts, forecasting_operations, analytics, scenario_operations, internal_demo, audit, ml_insights, validation, historical_validation, webhooks, performance_monitoring, retraining
from app.api import forecasts, forecasting_operations, analytics, scenario_operations, internal_demo, audit, ml_insights, validation, historical_validation, webhooks, performance_monitoring, retraining, enterprise_forecasting
class ForecastingService(StandardFastAPIService):
@@ -176,6 +176,7 @@ service.add_router(historical_validation.router) # Historical validation endpoi
service.add_router(webhooks.router) # Webhooks endpoint
service.add_router(performance_monitoring.router) # Performance monitoring endpoint
service.add_router(retraining.router) # Retraining endpoint
service.add_router(enterprise_forecasting.router) # Enterprise forecasting endpoint
if __name__ == "__main__":
import uvicorn

View File

@@ -0,0 +1,228 @@
"""
Enterprise forecasting service for aggregated demand across parent-child tenants
"""
import logging
from typing import Dict, Any, List, Optional
from datetime import date, datetime
import json
import redis.asyncio as redis
from shared.clients.forecast_client import ForecastServiceClient
from shared.clients.tenant_client import TenantServiceClient
logger = logging.getLogger(__name__)
class EnterpriseForecastingService:
    """
    Service for aggregating forecasts across parent and child tenants.

    Aggregates per-tenant forecasts into network-wide totals and caches
    the result in Redis under keys shaped
    ``agg_forecast:{parent}:{start}:{end}:{product_or_all}``.
    """

    def __init__(
        self,
        forecast_client: ForecastServiceClient,
        tenant_client: TenantServiceClient,
        redis_client: redis.Redis
    ):
        self.forecast_client = forecast_client
        self.tenant_client = tenant_client
        # May be None when Redis is unavailable; cache ops are best-effort.
        self.redis_client = redis_client
        self.cache_ttl_seconds = 3600  # 1 hour TTL

    async def get_aggregated_forecast(
        self,
        parent_tenant_id: str,
        start_date: date,
        end_date: date,
        product_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get aggregated forecast across parent and all child tenants.

        Args:
            parent_tenant_id: Parent tenant ID
            start_date: Start date for forecast aggregation
            end_date: End date for forecast aggregation
            product_id: Optional product ID to filter by

        Returns:
            Dict with aggregated forecast data by date and product
        """
        # Key format must stay in sync with the invalidation pattern used by
        # the forecast event consumer (agg_forecast:{parent}:*:*:*).
        cache_key = f"agg_forecast:{parent_tenant_id}:{start_date}:{end_date}:{product_id or 'all'}"
        # Cache read is best-effort; any failure falls through to recompute.
        try:
            cached_result = await self.redis_client.get(cache_key)
            if cached_result:
                logger.info(f"Cache hit for aggregated forecast: {cache_key}")
                return json.loads(cached_result)
        except Exception as e:
            logger.warning(f"Cache read failed: {e}")
        logger.info(f"Computing aggregated forecast for parent {parent_tenant_id} from {start_date} to {end_date}")
        # Get child tenant IDs
        child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
        child_tenant_ids = [child['id'] for child in child_tenants]
        # Include parent tenant in the list for complete aggregation
        all_tenant_ids = [parent_tenant_id] + child_tenant_ids
        # Fetch forecasts for all tenants (parent + children)
        all_forecasts = {}
        tenant_contributions = {}  # Track which tenant contributed to each forecast
        for tenant_id in all_tenant_ids:
            try:
                tenant_forecasts = await self.forecast_client.get_forecasts(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    product_id=product_id
                )
                # assumes tenant_forecasts maps date string -> {product_id: forecast dict}
                # — TODO confirm against ForecastServiceClient.get_forecasts
                for forecast_date_str, products in tenant_forecasts.items():
                    if forecast_date_str not in all_forecasts:
                        all_forecasts[forecast_date_str] = {}
                        tenant_contributions[forecast_date_str] = {}
                    for product_id_key, forecast_data in products.items():
                        if product_id_key not in all_forecasts[forecast_date_str]:
                            all_forecasts[forecast_date_str][product_id_key] = {
                                'predicted_demand': 0,
                                'confidence_lower': 0,
                                'confidence_upper': 0,
                                'tenant_contributions': []
                            }
                        bucket = all_forecasts[forecast_date_str][product_id_key]
                        # Aggregate the forecast values
                        bucket['predicted_demand'] += forecast_data.get('predicted_demand', 0)
                        # For confidence intervals, we'll use a simple approach
                        # In a real implementation, this would require proper statistical combination
                        bucket['confidence_lower'] += forecast_data.get('confidence_lower', 0)
                        bucket['confidence_upper'] += forecast_data.get('confidence_upper', 0)
                        # Track contribution by tenant
                        bucket['tenant_contributions'].append({
                            'tenant_id': tenant_id,
                            'demand': forecast_data.get('predicted_demand', 0),
                            'confidence_lower': forecast_data.get('confidence_lower', 0),
                            'confidence_upper': forecast_data.get('confidence_upper', 0)
                        })
            except Exception as e:
                logger.error(f"Failed to fetch forecasts for tenant {tenant_id}: {e}")
                # Continue with other tenants even if one fails
        # Prepare result
        result = {
            "parent_tenant_id": parent_tenant_id,
            "aggregated_forecasts": all_forecasts,
            "tenant_contributions": tenant_contributions,
            "child_tenant_count": len(child_tenant_ids),
            "forecast_dates": list(all_forecasts.keys()),
            "computed_at": datetime.utcnow().isoformat()
        }
        # Cache write is best-effort as well.
        try:
            await self.redis_client.setex(
                cache_key,
                self.cache_ttl_seconds,
                json.dumps(result, default=str)  # Handle date serialization
            )
            logger.info(f"Forecast cached for {cache_key}")
        except Exception as e:
            logger.warning(f"Cache write failed: {e}")
        return result

    async def get_network_performance_metrics(
        self,
        parent_tenant_id: str,
        start_date: date,
        end_date: date
    ) -> Dict[str, Any]:
        """
        Get aggregated performance metrics across the tenant network.

        Args:
            parent_tenant_id: Parent tenant ID
            start_date: Start date for metrics
            end_date: End date for metrics

        Returns:
            Dict with aggregated performance metrics
        """
        child_tenants = await self.tenant_client.get_child_tenants(parent_tenant_id)
        child_tenant_ids = [child['id'] for child in child_tenants]
        # Include parent tenant in the list for complete aggregation
        all_tenant_ids = [parent_tenant_id] + child_tenant_ids
        total_sales = 0
        total_forecasted = 0
        total_accuracy = 0
        tenant_count = 0
        performance_data = {}
        for tenant_id in all_tenant_ids:
            try:
                # Fetch sales and forecast data for the period
                sales_data = await self._fetch_sales_data(tenant_id, start_date, end_date)
                # NOTE(review): calling get_aggregated_forecast per tenant also
                # aggregates that tenant's own children — confirm this is the
                # intended roll-up and does not double count.
                forecast_data = await self.get_aggregated_forecast(tenant_id, start_date, end_date)
                tenant_performance = {
                    'tenant_id': tenant_id,
                    'sales': sales_data.get('total_sales', 0),
                    # Sum predicted demand over every date; each `day` maps
                    # product_id -> forecast dict. BUG FIX: the original read
                    # day.get('predicted_demand') per product, which is almost
                    # always 0 — it must read each product's own value.
                    'forecasted': sum(
                        sum(product.get('predicted_demand', 0) for product in day.values())
                        if isinstance(day, dict) else day
                        for day in forecast_data.get('aggregated_forecasts', {}).values()
                    ),
                }
                # Calculate accuracy if both sales and forecast data exist
                if tenant_performance['sales'] > 0 and tenant_performance['forecasted'] > 0:
                    accuracy = 1 - abs(tenant_performance['forecasted'] - tenant_performance['sales']) / tenant_performance['sales']
                    tenant_performance['accuracy'] = max(0, min(1, accuracy))  # Clamp between 0 and 1
                else:
                    tenant_performance['accuracy'] = 0
                performance_data[tenant_id] = tenant_performance
                total_sales += tenant_performance['sales']
                total_forecasted += tenant_performance['forecasted']
                total_accuracy += tenant_performance['accuracy']
                tenant_count += 1
            except Exception as e:
                logger.error(f"Failed to fetch performance data for tenant {tenant_id}: {e}")
        network_performance = {
            "parent_tenant_id": parent_tenant_id,
            "total_sales": total_sales,
            "total_forecasted": total_forecasted,
            "average_accuracy": total_accuracy / tenant_count if tenant_count > 0 else 0,
            "tenant_count": tenant_count,
            "child_tenant_count": len(child_tenant_ids),
            "tenant_performances": performance_data,
            "computed_at": datetime.utcnow().isoformat()
        }
        return network_performance

    async def _fetch_sales_data(self, tenant_id: str, start_date: date, end_date: date) -> Dict[str, Any]:
        """
        Helper method to fetch sales data.

        Placeholder: returns zero sales. In a real implementation this would
        call the sales service, which currently forces every tenant's
        accuracy to 0 in get_network_performance_metrics.
        """
        return {
            'total_sales': 0,  # Placeholder - would come from sales service
            'date_range': f"{start_date} to {end_date}",
            'tenant_id': tenant_id
        }