Files
bakery-ia/services/forecasting/app/services/forecast_cache.py
2025-10-09 18:01:24 +02:00

519 lines
16 KiB
Python

# services/forecasting/app/services/forecast_cache.py
"""
Forecast Cache Service - Redis-based caching for forecast results
Provides service-level caching for forecast predictions to eliminate redundant
computations when multiple services (Orders, Production) request the same
forecast data within a short time window.
Cache Strategy:
- Key: forecast:{tenant_id}:{product_id}:{forecast_date}
- TTL: Until midnight of day after forecast_date
- Invalidation: On model retraining for specific products
- Metadata: Includes 'cached' flag for observability
"""
import hashlib
import json
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List
from uuid import UUID

import redis
import structlog
logger = structlog.get_logger()
class ForecastCacheService:
    """Service-level caching for forecast predictions.

    Stores forecast results in Redis under deterministic keys so multiple
    consumers (Orders, Production) requesting the same forecast within its
    validity window share one computation.

    NOTE(review): the public methods are declared ``async`` but use the
    synchronous ``redis`` client, so each call blocks the event loop for the
    duration of the Redis round-trip — consider ``redis.asyncio`` if callers
    run inside an asyncio loop.
    """

    def __init__(self, redis_url: str):
        """
        Initialize Redis connection for forecast caching

        Args:
            redis_url: Redis connection URL
        """
        self.redis_url = redis_url
        self._redis_client = None
        self._connect()

    def _connect(self):
        """Establish the Redis connection.

        On failure the client is left as None and every cache operation
        degrades to a cache-miss no-op rather than raising.
        """
        try:
            self._redis_client = redis.from_url(
                self.redis_url,
                decode_responses=True,
                socket_keepalive=True,
                # NOTE(review): raw socket-option ints (1/3/5 look like Linux
                # TCP_KEEPIDLE/TCP_KEEPINTVL/TCP_KEEPCNT) are platform-specific
                # — confirm for the deployment target.
                socket_keepalive_options={1: 1, 3: 3, 5: 5},
                retry_on_timeout=True,
                max_connections=100,  # Higher limit for forecast service
                health_check_interval=30
            )
            # Ping eagerly so connection problems surface at startup
            self._redis_client.ping()
            logger.info("Forecast cache Redis connection established")
        except Exception as e:
            logger.error("Failed to connect to forecast cache Redis", error=str(e))
            self._redis_client = None

    @property
    def redis(self):
        """Redis client, lazily reconnecting if the previous attempt failed."""
        if self._redis_client is None:
            self._connect()
        return self._redis_client

    def is_available(self) -> bool:
        """Return True if the Redis cache is reachable (PING succeeds)."""
        try:
            return self.redis is not None and self.redis.ping()
        except Exception:
            return False

    # ================================================================
    # FORECAST CACHING
    # ================================================================

    def _get_forecast_key(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> str:
        """Generate cache key for a single-product forecast."""
        return f"forecast:{tenant_id}:{product_id}:{forecast_date.isoformat()}"

    def _get_batch_forecast_key(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date
    ) -> str:
        """Generate cache key for a batch forecast.

        Product IDs are sorted so the key is independent of request order,
        then digested with SHA-256. A stable digest is required here: the
        builtin hash() is randomized per process (PYTHONHASHSEED), which
        would make batch keys unshareable across workers and restarts.
        """
        sorted_ids = sorted(str(pid) for pid in product_ids)
        products_hash = hashlib.sha256(
            ",".join(sorted_ids).encode("utf-8")
        ).hexdigest()[:16]
        return f"forecast:batch:{tenant_id}:{products_hash}:{forecast_date.isoformat()}"

    def _calculate_ttl(self, forecast_date: date) -> int:
        """
        Calculate TTL for forecast cache entry

        Forecasts expire at midnight of the day after forecast_date.
        This ensures forecasts remain cached throughout the forecasted day
        but don't become stale.

        Args:
            forecast_date: Date of the forecast

        Returns:
            TTL in seconds (clamped to [1 hour, 48 hours])
        """
        # Expire at midnight after forecast_date
        expiry_datetime = datetime.combine(
            forecast_date + timedelta(days=1),
            datetime.min.time()
        )
        # NOTE(review): naive local time — assumes server TZ matches the
        # bakery's business day; confirm if services run in UTC.
        now = datetime.now()
        ttl_seconds = int((expiry_datetime - now).total_seconds())
        # Minimum TTL of 1 hour, maximum of 48 hours
        return max(3600, min(ttl_seconds, 172800))

    def _delete_matching(self, *patterns: str) -> int:
        """Delete all keys matching the given glob patterns.

        Uses SCAN (non-blocking, incremental) instead of KEYS, which scans
        the entire keyspace in one blocking call and can stall production
        Redis. Deletes in batches to bound command size.

        Returns:
            Number of keys deleted across all patterns.
        """
        deleted = 0
        for pattern in patterns:
            batch = []
            for key in self.redis.scan_iter(match=pattern, count=500):
                batch.append(key)
                if len(batch) >= 500:
                    deleted += self.redis.delete(*batch)
                    batch = []
            if batch:
                deleted += self.redis.delete(*batch)
        return deleted

    async def get_cached_forecast(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached forecast if available

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast

        Returns:
            Cached forecast data (with 'cached'/'cache_hit_at' metadata
            added) or None if not found or Redis is unavailable
        """
        if not self.is_available():
            return None
        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            cached_data = self.redis.get(key)
            if cached_data:
                forecast_data = json.loads(cached_data)
                # Add cache hit metadata for observability
                forecast_data['cached'] = True
                forecast_data['cache_hit_at'] = datetime.now().isoformat()
                logger.info("Forecast cache HIT",
                            tenant_id=str(tenant_id),
                            product_id=str(product_id),
                            forecast_date=str(forecast_date))
                return forecast_data
            logger.debug("Forecast cache MISS",
                         tenant_id=str(tenant_id),
                         product_id=str(product_id),
                         forecast_date=str(forecast_date))
            return None
        except Exception as e:
            # Cache errors must never break the forecast path — treat as miss
            logger.error("Error retrieving cached forecast",
                         error=str(e),
                         tenant_id=str(tenant_id))
            return None

    async def cache_forecast(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date,
        forecast_data: Dict[str, Any]
    ) -> bool:
        """
        Cache forecast prediction result

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast
            forecast_data: Forecast prediction data to cache

        Returns:
            True if cached successfully, False otherwise
        """
        if not self.is_available():
            logger.warning("Redis not available, skipping forecast cache")
            return False
        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            ttl = self._calculate_ttl(forecast_date)
            # Add caching metadata
            cache_entry = {
                **forecast_data,
                'cached_at': datetime.now().isoformat(),
                'cache_key': key,
                'ttl_seconds': ttl
            }
            # Serialize and cache; default=str covers dates/UUIDs/Decimals
            self.redis.setex(
                key,
                ttl,
                json.dumps(cache_entry, default=str)
            )
            logger.info("Forecast cached successfully",
                        tenant_id=str(tenant_id),
                        product_id=str(product_id),
                        forecast_date=str(forecast_date),
                        ttl_hours=round(ttl / 3600, 2))
            return True
        except Exception as e:
            logger.error("Error caching forecast",
                         error=str(e),
                         tenant_id=str(tenant_id))
            return False

    async def get_cached_batch_forecast(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached batch forecast

        Args:
            tenant_id: Tenant identifier
            product_ids: List of product identifiers (order-insensitive)
            forecast_date: Date of forecast

        Returns:
            Cached batch forecast data or None
        """
        if not self.is_available():
            return None
        try:
            key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
            cached_data = self.redis.get(key)
            if cached_data:
                forecast_data = json.loads(cached_data)
                forecast_data['cached'] = True
                forecast_data['cache_hit_at'] = datetime.now().isoformat()
                logger.info("Batch forecast cache HIT",
                            tenant_id=str(tenant_id),
                            products_count=len(product_ids),
                            forecast_date=str(forecast_date))
                return forecast_data
            return None
        except Exception as e:
            logger.error("Error retrieving cached batch forecast", error=str(e))
            return None

    async def cache_batch_forecast(
        self,
        tenant_id: UUID,
        product_ids: List[UUID],
        forecast_date: date,
        forecast_data: Dict[str, Any]
    ) -> bool:
        """Cache batch forecast result; returns True on success."""
        if not self.is_available():
            return False
        try:
            key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
            ttl = self._calculate_ttl(forecast_date)
            cache_entry = {
                **forecast_data,
                'cached_at': datetime.now().isoformat(),
                'cache_key': key,
                'ttl_seconds': ttl
            }
            self.redis.setex(key, ttl, json.dumps(cache_entry, default=str))
            logger.info("Batch forecast cached successfully",
                        tenant_id=str(tenant_id),
                        products_count=len(product_ids),
                        ttl_hours=round(ttl / 3600, 2))
            return True
        except Exception as e:
            logger.error("Error caching batch forecast", error=str(e))
            return False

    # ================================================================
    # CACHE INVALIDATION
    # ================================================================

    async def invalidate_product_forecasts(
        self,
        tenant_id: UUID,
        product_id: UUID
    ) -> int:
        """
        Invalidate all forecast cache entries for a product

        Called when model is retrained for specific product.

        NOTE(review): batch entries (forecast:batch:{tenant}:...) that
        include this product are NOT invalidated here — product membership
        is hidden inside the key digest, so they only age out via TTL.

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0
        try:
            # Delete all dates cached for this product (SCAN, non-blocking)
            deleted = self._delete_matching(
                f"forecast:{tenant_id}:{product_id}:*"
            )
            if deleted:
                logger.info("Invalidated product forecast cache",
                            tenant_id=str(tenant_id),
                            product_id=str(product_id),
                            keys_deleted=deleted)
            return deleted
        except Exception as e:
            logger.error("Error invalidating product forecasts",
                         error=str(e),
                         tenant_id=str(tenant_id))
            return 0

    async def invalidate_tenant_forecasts(
        self,
        tenant_id: UUID,
        forecast_date: Optional[date] = None
    ) -> int:
        """
        Invalidate forecast cache for tenant

        Covers both single-product keys (forecast:{tenant}:...) and batch
        keys (forecast:batch:{tenant}:...) — the batch prefix differs, so a
        single pattern would silently leave stale batch entries behind.

        Args:
            tenant_id: Tenant identifier
            forecast_date: Optional specific date to invalidate

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0
        try:
            if forecast_date:
                day = forecast_date.isoformat()
                patterns = (
                    f"forecast:{tenant_id}:*:{day}",
                    f"forecast:batch:{tenant_id}:*:{day}",
                )
            else:
                patterns = (
                    f"forecast:{tenant_id}:*",
                    f"forecast:batch:{tenant_id}:*",
                )
            deleted = self._delete_matching(*patterns)
            if deleted:
                logger.info("Invalidated tenant forecast cache",
                            tenant_id=str(tenant_id),
                            forecast_date=str(forecast_date) if forecast_date else "all",
                            keys_deleted=deleted)
            return deleted
        except Exception as e:
            logger.error("Error invalidating tenant forecasts", error=str(e))
            return 0

    async def invalidate_all_forecasts(self) -> int:
        """
        Invalidate all forecast cache entries (use with caution)

        Returns:
            Number of cache entries invalidated
        """
        if not self.is_available():
            return 0
        try:
            deleted = self._delete_matching("forecast:*")
            if deleted:
                logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
            return deleted
        except Exception as e:
            logger.error("Error invalidating all forecasts", error=str(e))
            return 0

    # ================================================================
    # CACHE STATISTICS & MONITORING
    # ================================================================

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics for monitoring

        Returns:
            Dictionary with cache metrics (key counts, memory, hit rate)
        """
        if not self.is_available():
            return {"available": False}
        try:
            info = self.redis.info()
            # Count keys with one incremental SCAN pass instead of two
            # blocking KEYS calls; batch keys share the "forecast:" prefix.
            total_keys = 0
            batch_keys = 0
            for key in self.redis.scan_iter(match="forecast:*", count=500):
                total_keys += 1
                if key.startswith("forecast:batch:"):
                    batch_keys += 1
            return {
                "available": True,
                "total_forecast_keys": total_keys,
                "batch_forecast_keys": batch_keys,
                "single_forecast_keys": total_keys - batch_keys,
                "used_memory": info.get("used_memory_human"),
                "connected_clients": info.get("connected_clients"),
                "keyspace_hits": info.get("keyspace_hits", 0),
                "keyspace_misses": info.get("keyspace_misses", 0),
                "hit_rate_percent": self._calculate_hit_rate(
                    info.get("keyspace_hits", 0),
                    info.get("keyspace_misses", 0)
                ),
                "total_commands_processed": info.get("total_commands_processed", 0)
            }
        except Exception as e:
            logger.error("Error getting cache stats", error=str(e))
            return {"available": False, "error": str(e)}

    def _calculate_hit_rate(self, hits: int, misses: int) -> float:
        """Calculate cache hit rate percentage (0.0 when no traffic)."""
        total = hits + misses
        return round((hits / total * 100), 2) if total > 0 else 0.0

    async def get_cached_forecast_info(
        self,
        tenant_id: UUID,
        product_id: UUID,
        forecast_date: date
    ) -> Optional[Dict[str, Any]]:
        """
        Get metadata about cached forecast without retrieving full data

        Args:
            tenant_id: Tenant identifier
            product_id: Product identifier
            forecast_date: Date of forecast

        Returns:
            Cache metadata (key, remaining TTL, expiry time) or None
        """
        if not self.is_available():
            return None
        try:
            key = self._get_forecast_key(tenant_id, product_id, forecast_date)
            # TTL is -2 for missing keys, -1 for keys without expiry
            ttl = self.redis.ttl(key)
            if ttl > 0:
                return {
                    "cached": True,
                    "cache_key": key,
                    "ttl_seconds": ttl,
                    "ttl_hours": round(ttl / 3600, 2),
                    "expires_at": (datetime.now() + timedelta(seconds=ttl)).isoformat()
                }
            return None
        except Exception as e:
            logger.error("Error getting forecast cache info", error=str(e))
            return None
# Module-level singleton holding the shared cache service
_cache_service = None


def get_forecast_cache_service(redis_url: Optional[str] = None) -> ForecastCacheService:
    """
    Get the global forecast cache service instance

    Args:
        redis_url: Redis connection URL (required for first call)

    Returns:
        ForecastCacheService instance

    Raises:
        ValueError: if called before initialization without a redis_url
    """
    global _cache_service
    # Fast path: already initialized — redis_url is ignored on later calls
    if _cache_service is not None:
        return _cache_service
    if redis_url is None:
        raise ValueError("redis_url required for first initialization")
    _cache_service = ForecastCacheService(redis_url)
    return _cache_service