# services/forecasting/app/services/forecast_cache.py
"""
Forecast Cache Service - Redis-based caching for forecast results

Provides service-level caching for forecast predictions to eliminate redundant
computations when multiple services (Orders, Production) request the same
forecast data within a short time window.

Cache Strategy:
- Key: forecast:{tenant_id}:{product_id}:{forecast_date}
- TTL: Until midnight of day after forecast_date
- Invalidation: On model retraining for specific products
- Metadata: Includes 'cached' flag for observability
"""
import hashlib
import json
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, List
from uuid import UUID

import structlog

from shared.redis_utils import get_redis_client
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
class ForecastCacheService:
|
|
"""Service-level caching for forecast predictions"""
|
|
|
|
def __init__(self):
|
|
"""Initialize forecast cache service"""
|
|
pass
|
|
|
|
async def _get_redis(self):
|
|
"""Get shared Redis client"""
|
|
return await get_redis_client()
|
|
|
|
async def is_available(self) -> bool:
|
|
"""Check if Redis cache is available"""
|
|
try:
|
|
client = await self._get_redis()
|
|
await client.ping()
|
|
return True
|
|
except Exception:
|
|
return False
|
|
|
|
# ================================================================
|
|
# FORECAST CACHING
|
|
# ================================================================
|
|
|
|
def _get_forecast_key(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_id: UUID,
|
|
forecast_date: date
|
|
) -> str:
|
|
"""Generate cache key for forecast"""
|
|
return f"forecast:{tenant_id}:{product_id}:{forecast_date.isoformat()}"
|
|
|
|
def _get_batch_forecast_key(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_ids: List[UUID],
|
|
forecast_date: date
|
|
) -> str:
|
|
"""Generate cache key for batch forecast"""
|
|
# Sort product IDs for consistent key generation
|
|
sorted_ids = sorted(str(pid) for pid in product_ids)
|
|
products_hash = hash(tuple(sorted_ids))
|
|
return f"forecast:batch:{tenant_id}:{products_hash}:{forecast_date.isoformat()}"
|
|
|
|
def _calculate_ttl(self, forecast_date: date) -> int:
|
|
"""
|
|
Calculate TTL for forecast cache entry
|
|
|
|
Forecasts expire at midnight of the day after forecast_date.
|
|
This ensures forecasts remain cached throughout the forecasted day
|
|
but don't become stale.
|
|
|
|
Args:
|
|
forecast_date: Date of the forecast
|
|
|
|
Returns:
|
|
TTL in seconds
|
|
"""
|
|
# Expire at midnight after forecast_date
|
|
expiry_datetime = datetime.combine(
|
|
forecast_date + timedelta(days=1),
|
|
datetime.min.time()
|
|
)
|
|
|
|
now = datetime.now()
|
|
ttl_seconds = int((expiry_datetime - now).total_seconds())
|
|
|
|
# Minimum TTL of 1 hour, maximum of 48 hours
|
|
return max(3600, min(ttl_seconds, 172800))
|
|
|
|
async def get_cached_forecast(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_id: UUID,
|
|
forecast_date: date
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Retrieve cached forecast if available
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
product_id: Product identifier
|
|
forecast_date: Date of forecast
|
|
|
|
Returns:
|
|
Cached forecast data or None if not found
|
|
"""
|
|
if not await self.is_available():
|
|
return None
|
|
|
|
try:
|
|
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
|
client = await self._get_redis()
|
|
cached_data = await client.get(key)
|
|
|
|
if cached_data:
|
|
forecast_data = json.loads(cached_data)
|
|
# Add cache hit metadata
|
|
forecast_data['cached'] = True
|
|
forecast_data['cache_hit_at'] = datetime.now().isoformat()
|
|
|
|
logger.info("Forecast cache HIT",
|
|
tenant_id=str(tenant_id),
|
|
product_id=str(product_id),
|
|
forecast_date=str(forecast_date))
|
|
return forecast_data
|
|
|
|
logger.debug("Forecast cache MISS",
|
|
tenant_id=str(tenant_id),
|
|
product_id=str(product_id),
|
|
forecast_date=str(forecast_date))
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.error("Error retrieving cached forecast",
|
|
error=str(e),
|
|
tenant_id=str(tenant_id))
|
|
return None
|
|
|
|
async def cache_forecast(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_id: UUID,
|
|
forecast_date: date,
|
|
forecast_data: Dict[str, Any]
|
|
) -> bool:
|
|
"""
|
|
Cache forecast prediction result
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
product_id: Product identifier
|
|
forecast_date: Date of forecast
|
|
forecast_data: Forecast prediction data to cache
|
|
|
|
Returns:
|
|
True if cached successfully, False otherwise
|
|
"""
|
|
if not await self.is_available():
|
|
logger.warning("Redis not available, skipping forecast cache")
|
|
return False
|
|
|
|
try:
|
|
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
|
ttl = self._calculate_ttl(forecast_date)
|
|
|
|
# Add caching metadata
|
|
cache_entry = {
|
|
**forecast_data,
|
|
'cached_at': datetime.now().isoformat(),
|
|
'cache_key': key,
|
|
'ttl_seconds': ttl
|
|
}
|
|
|
|
# Serialize and cache
|
|
client = await self._get_redis()
|
|
await client.setex(
|
|
key,
|
|
ttl,
|
|
json.dumps(cache_entry, default=str)
|
|
)
|
|
|
|
logger.info("Forecast cached successfully",
|
|
tenant_id=str(tenant_id),
|
|
product_id=str(product_id),
|
|
forecast_date=str(forecast_date),
|
|
ttl_hours=round(ttl / 3600, 2))
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error("Error caching forecast",
|
|
error=str(e),
|
|
tenant_id=str(tenant_id))
|
|
return False
|
|
|
|
async def get_cached_batch_forecast(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_ids: List[UUID],
|
|
forecast_date: date
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Retrieve cached batch forecast
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
product_ids: List of product identifiers
|
|
forecast_date: Date of forecast
|
|
|
|
Returns:
|
|
Cached batch forecast data or None
|
|
"""
|
|
if not await self.is_available():
|
|
return None
|
|
|
|
try:
|
|
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
|
|
client = await self._get_redis()
|
|
cached_data = await client.get(key)
|
|
|
|
if cached_data:
|
|
forecast_data = json.loads(cached_data)
|
|
forecast_data['cached'] = True
|
|
forecast_data['cache_hit_at'] = datetime.now().isoformat()
|
|
|
|
logger.info("Batch forecast cache HIT",
|
|
tenant_id=str(tenant_id),
|
|
products_count=len(product_ids),
|
|
forecast_date=str(forecast_date))
|
|
return forecast_data
|
|
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.error("Error retrieving cached batch forecast", error=str(e))
|
|
return None
|
|
|
|
async def cache_batch_forecast(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_ids: List[UUID],
|
|
forecast_date: date,
|
|
forecast_data: Dict[str, Any]
|
|
) -> bool:
|
|
"""Cache batch forecast result"""
|
|
if not await self.is_available():
|
|
return False
|
|
|
|
try:
|
|
key = self._get_batch_forecast_key(tenant_id, product_ids, forecast_date)
|
|
ttl = self._calculate_ttl(forecast_date)
|
|
|
|
cache_entry = {
|
|
**forecast_data,
|
|
'cached_at': datetime.now().isoformat(),
|
|
'cache_key': key,
|
|
'ttl_seconds': ttl
|
|
}
|
|
|
|
client = await self._get_redis()
|
|
await client.setex(key, ttl, json.dumps(cache_entry, default=str))
|
|
|
|
logger.info("Batch forecast cached successfully",
|
|
tenant_id=str(tenant_id),
|
|
products_count=len(product_ids),
|
|
ttl_hours=round(ttl / 3600, 2))
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error("Error caching batch forecast", error=str(e))
|
|
return False
|
|
|
|
# ================================================================
|
|
# CACHE INVALIDATION
|
|
# ================================================================
|
|
|
|
async def invalidate_product_forecasts(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_id: UUID
|
|
) -> int:
|
|
"""
|
|
Invalidate all forecast cache entries for a product
|
|
|
|
Called when model is retrained for specific product.
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
product_id: Product identifier
|
|
|
|
Returns:
|
|
Number of cache entries invalidated
|
|
"""
|
|
if not await self.is_available():
|
|
return 0
|
|
|
|
try:
|
|
# Find all keys matching this product
|
|
pattern = f"forecast:{tenant_id}:{product_id}:*"
|
|
client = await self._get_redis()
|
|
keys = await client.keys(pattern)
|
|
|
|
if keys:
|
|
deleted = await client.delete(*keys)
|
|
logger.info("Invalidated product forecast cache",
|
|
tenant_id=str(tenant_id),
|
|
product_id=str(product_id),
|
|
keys_deleted=deleted)
|
|
return deleted
|
|
|
|
return 0
|
|
|
|
except Exception as e:
|
|
logger.error("Error invalidating product forecasts",
|
|
error=str(e),
|
|
tenant_id=str(tenant_id))
|
|
return 0
|
|
|
|
async def invalidate_tenant_forecasts(
|
|
self,
|
|
tenant_id: UUID,
|
|
forecast_date: Optional[date] = None
|
|
) -> int:
|
|
"""
|
|
Invalidate forecast cache for tenant
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
forecast_date: Optional specific date to invalidate
|
|
|
|
Returns:
|
|
Number of cache entries invalidated
|
|
"""
|
|
if not await self.is_available():
|
|
return 0
|
|
|
|
try:
|
|
if forecast_date:
|
|
pattern = f"forecast:{tenant_id}:*:{forecast_date.isoformat()}"
|
|
else:
|
|
pattern = f"forecast:{tenant_id}:*"
|
|
|
|
client = await self._get_redis()
|
|
keys = await client.keys(pattern)
|
|
|
|
if keys:
|
|
deleted = await client.delete(*keys)
|
|
logger.info("Invalidated tenant forecast cache",
|
|
tenant_id=str(tenant_id),
|
|
forecast_date=str(forecast_date) if forecast_date else "all",
|
|
keys_deleted=deleted)
|
|
return deleted
|
|
|
|
return 0
|
|
|
|
except Exception as e:
|
|
logger.error("Error invalidating tenant forecasts", error=str(e))
|
|
return 0
|
|
|
|
async def invalidate_all_forecasts(self) -> int:
|
|
"""
|
|
Invalidate all forecast cache entries (use with caution)
|
|
|
|
Returns:
|
|
Number of cache entries invalidated
|
|
"""
|
|
if not await self.is_available():
|
|
return 0
|
|
|
|
try:
|
|
pattern = "forecast:*"
|
|
client = await self._get_redis()
|
|
keys = await client.keys(pattern)
|
|
|
|
if keys:
|
|
deleted = await client.delete(*keys)
|
|
logger.warning("Invalidated ALL forecast cache", keys_deleted=deleted)
|
|
return deleted
|
|
|
|
return 0
|
|
|
|
except Exception as e:
|
|
logger.error("Error invalidating all forecasts", error=str(e))
|
|
return 0
|
|
|
|
# ================================================================
|
|
# CACHE STATISTICS & MONITORING
|
|
# ================================================================
|
|
|
|
async def get_cache_stats(self) -> Dict[str, Any]:
|
|
"""
|
|
Get cache statistics for monitoring
|
|
|
|
Returns:
|
|
Dictionary with cache metrics
|
|
"""
|
|
if not await self.is_available():
|
|
return {"available": False}
|
|
|
|
try:
|
|
client = await self._get_redis()
|
|
info = await client.info()
|
|
|
|
# Get forecast-specific stats
|
|
forecast_keys = await client.keys("forecast:*")
|
|
batch_keys = await client.keys("forecast:batch:*")
|
|
|
|
return {
|
|
"available": True,
|
|
"total_forecast_keys": len(forecast_keys),
|
|
"batch_forecast_keys": len(batch_keys),
|
|
"single_forecast_keys": len(forecast_keys) - len(batch_keys),
|
|
"used_memory": info.get("used_memory_human"),
|
|
"connected_clients": info.get("connected_clients"),
|
|
"keyspace_hits": info.get("keyspace_hits", 0),
|
|
"keyspace_misses": info.get("keyspace_misses", 0),
|
|
"hit_rate_percent": self._calculate_hit_rate(
|
|
info.get("keyspace_hits", 0),
|
|
info.get("keyspace_misses", 0)
|
|
),
|
|
"total_commands_processed": info.get("total_commands_processed", 0)
|
|
}
|
|
except Exception as e:
|
|
logger.error("Error getting cache stats", error=str(e))
|
|
return {"available": False, "error": str(e)}
|
|
|
|
def _calculate_hit_rate(self, hits: int, misses: int) -> float:
|
|
"""Calculate cache hit rate percentage"""
|
|
total = hits + misses
|
|
return round((hits / total * 100), 2) if total > 0 else 0.0
|
|
|
|
async def get_cached_forecast_info(
|
|
self,
|
|
tenant_id: UUID,
|
|
product_id: UUID,
|
|
forecast_date: date
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
Get metadata about cached forecast without retrieving full data
|
|
|
|
Args:
|
|
tenant_id: Tenant identifier
|
|
product_id: Product identifier
|
|
forecast_date: Date of forecast
|
|
|
|
Returns:
|
|
Cache metadata or None
|
|
"""
|
|
if not await self.is_available():
|
|
return None
|
|
|
|
try:
|
|
key = self._get_forecast_key(tenant_id, product_id, forecast_date)
|
|
client = await self._get_redis()
|
|
ttl = await client.ttl(key)
|
|
|
|
if ttl > 0:
|
|
return {
|
|
"cached": True,
|
|
"cache_key": key,
|
|
"ttl_seconds": ttl,
|
|
"ttl_hours": round(ttl / 3600, 2),
|
|
"expires_at": (datetime.now() + timedelta(seconds=ttl)).isoformat()
|
|
}
|
|
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.error("Error getting forecast cache info", error=str(e))
|
|
return None
|
|
|
|
|
|
# Global cache service instance
|
|
_cache_service = None
|
|
|
|
|
|
def get_forecast_cache_service() -> ForecastCacheService:
|
|
"""
|
|
Get the global forecast cache service instance
|
|
|
|
Returns:
|
|
ForecastCacheService instance
|
|
"""
|
|
global _cache_service
|
|
|
|
if _cache_service is None:
|
|
_cache_service = ForecastCacheService()
|
|
|
|
return _cache_service
|