REFACTOR production scheduler
This commit is contained in:
@@ -11,6 +11,7 @@ import uuid
|
||||
|
||||
from app.services.forecasting_service import EnhancedForecastingService
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.services.forecast_cache import get_forecast_cache_service
|
||||
from app.schemas.forecasts import (
|
||||
ForecastRequest, ForecastResponse, BatchForecastRequest,
|
||||
BatchForecastResponse, MultiDayForecastResponse
|
||||
@@ -53,7 +54,7 @@ async def generate_single_forecast(
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
enhanced_forecasting_service: EnhancedForecastingService = Depends(get_enhanced_forecasting_service)
|
||||
):
|
||||
"""Generate a single product forecast"""
|
||||
"""Generate a single product forecast with caching support"""
|
||||
metrics = get_metrics_collector(request_obj)
|
||||
|
||||
try:
|
||||
@@ -65,11 +66,41 @@ async def generate_single_forecast(
|
||||
if metrics:
|
||||
metrics.increment_counter("single_forecasts_total")
|
||||
|
||||
# Initialize cache service
|
||||
cache_service = get_forecast_cache_service(settings.REDIS_URL)
|
||||
|
||||
# Check cache first
|
||||
cached_forecast = await cache_service.get_cached_forecast(
|
||||
tenant_id=uuid.UUID(tenant_id),
|
||||
product_id=uuid.UUID(request.inventory_product_id),
|
||||
forecast_date=request.forecast_date
|
||||
)
|
||||
|
||||
if cached_forecast:
|
||||
if metrics:
|
||||
metrics.increment_counter("forecast_cache_hits_total")
|
||||
logger.info("Returning cached forecast",
|
||||
tenant_id=tenant_id,
|
||||
forecast_id=cached_forecast.get('id'))
|
||||
return ForecastResponse(**cached_forecast)
|
||||
|
||||
# Cache miss - generate forecast
|
||||
if metrics:
|
||||
metrics.increment_counter("forecast_cache_misses_total")
|
||||
|
||||
forecast = await enhanced_forecasting_service.generate_forecast(
|
||||
tenant_id=tenant_id,
|
||||
request=request
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
await cache_service.cache_forecast(
|
||||
tenant_id=uuid.UUID(tenant_id),
|
||||
product_id=uuid.UUID(request.inventory_product_id),
|
||||
forecast_date=request.forecast_date,
|
||||
forecast_data=forecast.dict()
|
||||
)
|
||||
|
||||
if metrics:
|
||||
metrics.increment_counter("single_forecasts_success_total")
|
||||
|
||||
|
||||
518
services/forecasting/app/services/forecast_cache.py
Normal file
518
services/forecasting/app/services/forecast_cache.py
Normal file
@@ -0,0 +1,518 @@
|
||||
# services/forecasting/app/services/forecast_cache.py
|
||||
"""
|
||||
Forecast Cache Service - Redis-based caching for forecast results
|
||||
|
||||
Provides service-level caching for forecast predictions to eliminate redundant
|
||||
computations when multiple services (Orders, Production) request the same
|
||||
forecast data within a short time window.
|
||||
|
||||
Cache Strategy:
|
||||
- Key: forecast:{tenant_id}:{product_id}:{forecast_date}
|
||||
- TTL: Until midnight of day after forecast_date
|
||||
- Invalidation: On model retraining for specific products
|
||||
- Metadata: Includes 'cached' flag for observability
|
||||
"""
|
||||
|
||||
import hashlib
import json
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List
from uuid import UUID

import redis
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ForecastCacheService:
    """Redis-backed, service-level cache for forecast predictions.

    Keys follow ``forecast:{tenant_id}:{product_id}:{forecast_date}`` and
    expire at midnight after the forecast date, so repeated requests from
    multiple consumers (Orders, Production) within the same day reuse one
    computed forecast instead of recomputing it.

    NOTE(review): the public methods are declared ``async`` but use the
    synchronous ``redis`` client, so each call blocks the event loop for
    the duration of the Redis round-trip — acceptable for a fast local
    Redis, but worth confirming for this deployment.
    """

    def __init__(self, redis_url: str):
        """
        Initialize Redis connection for forecast caching.

        Args:
            redis_url: Redis connection URL
        """
        self.redis_url = redis_url
        self._redis_client = None
        self._connect()

    def _connect(self):
        """Establish the Redis connection; on failure leave the client as None."""
        try:
            self._redis_client = redis.from_url(
                self.redis_url,
                decode_responses=True,
                socket_keepalive=True,
                # Raw TCP keepalive socket-option ids; values are platform
                # dependent — TODO confirm these map to the intended options
                # on the deployment OS.
                socket_keepalive_options={1: 1, 3: 3, 5: 5},
                retry_on_timeout=True,
                max_connections=100,  # Higher limit for forecast service
                health_check_interval=30
            )
            # Test connection eagerly so failures surface at startup.
            self._redis_client.ping()
            logger.info("Forecast cache Redis connection established")
        except Exception as e:
            logger.error("Failed to connect to forecast cache Redis", error=str(e))
            self._redis_client = None

    @property
    def redis(self):
        """Get the Redis client, lazily reconnecting if the client is gone."""
        if self._redis_client is None:
            self._connect()
        return self._redis_client

    def is_available(self) -> bool:
        """Check if the Redis cache is reachable (PING succeeds)."""
        try:
            return self.redis is not None and self.redis.ping()
        except Exception:
            return False

    # ================================================================
    # FORECAST CACHING
    # ================================================================

    def _get_forecast_key(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> str:
        """Generate the cache key for a single-product forecast."""
        return f"forecast:{tenant_id}:{product_id}:{forecast_date.isoformat()}"

    def _get_batch_forecast_key(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date
    ) -> str:
        """Generate a deterministic cache key for a batch forecast.

        Product IDs are sorted and digested with SHA-256 so the same set of
        products always maps to the same key regardless of request order.
        The builtin ``hash()`` must NOT be used here: string hashing is
        randomized per process (PYTHONHASHSEED), so ``hash()``-based keys
        would never match across workers or restarts.
        """
        # Sort product IDs for consistent key generation
        sorted_ids = sorted(str(pid) for pid in product_ids)
        products_hash = hashlib.sha256(
            ",".join(sorted_ids).encode("utf-8")
        ).hexdigest()[:16]
        return f"forecast:batch:{tenant_id}:{products_hash}:{forecast_date.isoformat()}"

    def _calculate_ttl(self, forecast_date: date) -> int:
        """
        Calculate TTL for a forecast cache entry.

        Forecasts expire at midnight of the day after forecast_date.
        This ensures forecasts remain cached throughout the forecasted day
        but don't become stale.

        Args:
            forecast_date: Date of the forecast

        Returns:
            TTL in seconds, clamped to [1 hour, 48 hours]
        """
        # Expire at midnight after forecast_date
        expiry_datetime = datetime.combine(
            forecast_date + timedelta(days=1),
            datetime.min.time()
        )

        # NOTE(review): naive local time — if servers run in UTC while
        # forecast_date is tenant-local, expiry may be off by the tz offset;
        # confirm the intended timezone semantics.
        now = datetime.now()
        ttl_seconds = int((expiry_datetime - now).total_seconds())

        # Minimum TTL of 1 hour, maximum of 48 hours
        return max(3600, min(ttl_seconds, 172800))

    async def get_cached_forecast(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve a cached forecast if available.

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast

        Returns:
            Cached forecast data (with ``cached``/``cache_hit_at`` metadata
            added) or None on miss, cache unavailability, or error.
        """
        if not self.is_available():
            return None

        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            cached_data = self.redis.get(key)

            if cached_data:
                forecast_data = json.loads(cached_data)
                # Add cache hit metadata for observability downstream.
                forecast_data['cached'] = True
                forecast_data['cache_hit_at'] = datetime.now().isoformat()

                logger.info("Forecast cache HIT",
                           tenant_id=str(tenant_id),
                           product_id=str(product_id),
                           forecast_date=str(forecast_date))
                return forecast_data

            logger.debug("Forecast cache MISS",
                        tenant_id=str(tenant_id),
                        product_id=str(product_id),
                        forecast_date=str(forecast_date))
            return None

        except Exception as e:
            # Cache errors must never break the forecast path; fall through
            # to a recompute by reporting a miss.
            logger.error("Error retrieving cached forecast",
                        error=str(e),
                        tenant_id=str(tenant_id))
            return None

    async def cache_forecast(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date,
        forecast_data: Dict[str, Any]
    ) -> bool:
        """
        Cache a forecast prediction result.

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast
            forecast_data: Forecast prediction data to cache

        Returns:
            True if cached successfully, False otherwise
        """
        if not self.is_available():
            logger.warning("Redis not available, skipping forecast cache")
            return False

        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            ttl = self._calculate_ttl(forecast_date)

            # Add caching metadata
            cache_entry = {
                **forecast_data,
                'cached_at': datetime.now().isoformat(),
                'cache_key': key,
                'ttl_seconds': ttl
            }

            # Serialize and cache; default=str covers dates/UUIDs/Decimals.
            self.redis.setex(
                key,
                ttl,
                json.dumps(cache_entry, default=str)
            )

            logger.info("Forecast cached successfully",
                       tenant_id=str(tenant_id),
                       product_id=str(product_id),
                       forecast_date=str(forecast_date),
                       ttl_hours=round(ttl / 3600, 2))
            return True

        except Exception as e:
            logger.error("Error caching forecast",
                        error=str(e),
                        tenant_id=str(tenant_id))
            return False

    async def get_cached_batch_forecast(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve a cached batch forecast.

        Args:
            tenant_id: Tenant identifier
            product_ids: List of product identifiers
            forecast_date: Date of forecast

        Returns:
            Cached batch forecast data or None
        """
        if not self.is_available():
            return None

        try:
            key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
            cached_data = self.redis.get(key)

            if cached_data:
                forecast_data = json.loads(cached_data)
                forecast_data['cached'] = True
                forecast_data['cache_hit_at'] = datetime.now().isoformat()

                logger.info("Batch forecast cache HIT",
                           tenant_id=str(tenant_id),
                           products_count=len(product_ids),
                           forecast_date=str(forecast_date))
                return forecast_data

            return None

        except Exception as e:
            logger.error("Error retrieving cached batch forecast", error=str(e))
            return None

    async def cache_batch_forecast(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date,
        forecast_data: Dict[str, Any]
    ) -> bool:
        """Cache a batch forecast result; returns True on success."""
        if not self.is_available():
            return False

        try:
            key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
            ttl = self._calculate_ttl(forecast_date)

            cache_entry = {
                **forecast_data,
                'cached_at': datetime.now().isoformat(),
                'cache_key': key,
                'ttl_seconds': ttl
            }

            self.redis.setex(key, ttl, json.dumps(cache_entry, default=str))

            logger.info("Batch forecast cached successfully",
                       tenant_id=str(tenant_id),
                       products_count=len(product_ids),
                       ttl_hours=round(ttl / 3600, 2))
            return True

        except Exception as e:
            logger.error("Error caching batch forecast", error=str(e))
            return False

    # ================================================================
    # CACHE INVALIDATION
    # ================================================================

    async def invalidate_product_forecasts(
        self,
        tenant_id: UUID,
        product_id: UUID
    ) -> int:
        """
        Invalidate all forecast cache entries for a product.

        Called when the model is retrained for a specific product.

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0

        try:
            # SCAN (cursor-based) instead of KEYS: KEYS is O(N) and blocks
            # the Redis server, which is unsafe on a shared production cache.
            pattern = f"forecast:{tenant_id}:{product_id}:*"
            keys = list(self.redis.scan_iter(match=pattern))

            if keys:
                deleted = self.redis.delete(*keys)
                logger.info("Invalidated product forecast cache",
                           tenant_id=str(tenant_id),
                           product_id=str(product_id),
                           keys_deleted=deleted)
                return deleted

            return 0

        except Exception as e:
            logger.error("Error invalidating product forecasts",
                        error=str(e),
                        tenant_id=str(tenant_id))
            return 0

    async def invalidate_tenant_forecasts(
        self,
        tenant_id: UUID,
        forecast_date: Optional[date] = None
    ) -> int:
        """
        Invalidate forecast cache for a tenant.

        Args:
            tenant_id: Tenant identifier
            forecast_date: Optional specific date to invalidate

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0

        try:
            if forecast_date:
                pattern = f"forecast:{tenant_id}:*:{forecast_date.isoformat()}"
            else:
                pattern = f"forecast:{tenant_id}:*"

            # SCAN instead of KEYS to avoid blocking the Redis server.
            keys = list(self.redis.scan_iter(match=pattern))

            if keys:
                deleted = self.redis.delete(*keys)
                logger.info("Invalidated tenant forecast cache",
                           tenant_id=str(tenant_id),
                           forecast_date=str(forecast_date) if forecast_date else "all",
                           keys_deleted=deleted)
                return deleted

            return 0

        except Exception as e:
            logger.error("Error invalidating tenant forecasts", error=str(e))
            return 0

    async def invalidate_all_forecasts(self) -> int:
        """
        Invalidate all forecast cache entries (use with caution).

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0

        try:
            # SCAN instead of KEYS to avoid blocking the Redis server.
            keys = list(self.redis.scan_iter(match="forecast:*"))

            if keys:
                deleted = self.redis.delete(*keys)
                logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
                return deleted

            return 0

        except Exception as e:
            logger.error("Error invalidating all forecasts", error=str(e))
            return 0

    # ================================================================
    # CACHE STATISTICS & MONITORING
    # ================================================================

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics for monitoring.

        Returns:
            Dictionary with cache metrics; ``{"available": False}`` when
            Redis is unreachable.
        """
        if not self.is_available():
            return {"available": False}

        try:
            info = self.redis.info()

            # Count forecast-specific keys via SCAN (non-blocking). Note:
            # batch keys also match "forecast:*", hence the subtraction below.
            total_forecast_keys = sum(1 for _ in self.redis.scan_iter(match="forecast:*"))
            batch_forecast_keys = sum(1 for _ in self.redis.scan_iter(match="forecast:batch:*"))

            return {
                "available": True,
                "total_forecast_keys": total_forecast_keys,
                "batch_forecast_keys": batch_forecast_keys,
                "single_forecast_keys": total_forecast_keys - batch_forecast_keys,
                "used_memory": info.get("used_memory_human"),
                "connected_clients": info.get("connected_clients"),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
                "hit_rate_percent": self._calculate_hit_rate(
                    info.get("keyspace_hits", 0),
                    info.get("keyspace_misses", 0)
                ),
                "total_commands_processed": info.get("total_commands_processed", 0)
            }
        except Exception as e:
            logger.error("Error getting cache stats", error=str(e))
            return {"available": False, "error": str(e)}

    def _calculate_hit_rate(self, hits: int, misses: int) -> float:
        """Calculate the cache hit rate as a percentage (0.0 when no traffic)."""
        total = hits + misses
        return round((hits / total * 100), 2) if total > 0 else 0.0

    async def get_cached_forecast_info(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Get metadata about a cached forecast without retrieving the full data.

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast

        Returns:
            Cache metadata (key, TTL, expiry) or None when not cached
        """
        if not self.is_available():
            return None

        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            # TTL is -2 for missing keys and -1 for keys without expiry,
            # so "ttl > 0" doubles as an existence check here.
            ttl = self.redis.ttl(key)

            if ttl > 0:
                return {
                    "cached": True,
                    "cache_key": key,
                    "ttl_seconds": ttl,
                    "ttl_hours": round(ttl / 3600, 2),
                    "expires_at": (datetime.now() + timedelta(seconds=ttl)).isoformat()
                }

            return None

        except Exception as e:
            logger.error("Error getting forecast cache info", error=str(e))
            return None
||||
|
||||
# Module-level singleton holding the shared cache service.
_cache_service = None


def get_forecast_cache_service(redis_url: Optional[str] = None) -> ForecastCacheService:
    """Return the process-wide ForecastCacheService singleton.

    The first call must supply ``redis_url``; subsequent calls reuse the
    already-constructed instance and ignore the argument.

    Args:
        redis_url: Redis connection URL (required for first call)

    Returns:
        ForecastCacheService instance

    Raises:
        ValueError: if called before initialization without a redis_url
    """
    global _cache_service

    if _cache_service is not None:
        return _cache_service

    if redis_url is None:
        raise ValueError("redis_url required for first initialization")

    _cache_service = ForecastCacheService(redis_url)
    return _cache_service
|
||||
Reference in New Issue
Block a user