# services/orchestrator/app/utils/cache.py
|
|
"""
|
|
Redis caching utilities for dashboard endpoints
|
|
"""
|
|
|
|
import json
|
|
import redis.asyncio as redis
|
|
from typing import Optional, Any, Callable
|
|
from functools import wraps
|
|
import structlog
|
|
from app.core.config import settings
|
|
from pydantic import BaseModel
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
# Module-level Redis client singleton. Lazily created by get_redis_client();
# reset to None on connection failure and by close_redis(). A value of None
# means "Redis unavailable — caching disabled".
_redis_client: Optional[redis.Redis] = None
|
|
|
|
|
|
async def get_redis_client() -> Optional[redis.Redis]:
    """Lazily create and return the shared Redis client.

    Connection parameters are read from settings (with safe defaults) and a
    PING is issued before the client is published globally, so callers never
    observe a client that cannot actually talk to Redis.

    Returns:
        The connected client, or None if the connection attempt failed
        (callers treat a None client as "caching disabled").
    """
    global _redis_client

    if _redis_client is None:
        client = None
        try:
            # REDIS_TLS_ENABLED may arrive as a string from the environment;
            # normalize common truthy spellings to a real boolean.
            redis_tls_str = str(getattr(settings, 'REDIS_TLS_ENABLED', 'false')).lower()
            redis_tls_enabled = redis_tls_str in ('true', '1', 'yes', 'on')

            connection_kwargs = {
                'host': str(getattr(settings, 'REDIS_HOST', 'localhost')),
                'port': int(getattr(settings, 'REDIS_PORT', 6379)),
                'db': int(getattr(settings, 'REDIS_DB', 0)),
                'decode_responses': True,
                'socket_connect_timeout': 5,
                'socket_timeout': 5
            }

            # Only send a password when one is configured.
            redis_password = getattr(settings, 'REDIS_PASSWORD', None)
            if redis_password:
                connection_kwargs['password'] = redis_password

            # Add SSL/TLS support if enabled
            if redis_tls_enabled:
                import ssl
                connection_kwargs['ssl'] = True
                # NOTE(review): CERT_NONE disables server certificate
                # verification; acceptable only for internal deployments —
                # confirm before exposing this to untrusted networks.
                connection_kwargs['ssl_cert_reqs'] = ssl.CERT_NONE
                logger.debug(f"Redis TLS enabled - connecting with SSL to {connection_kwargs['host']}:{connection_kwargs['port']}")

            client = redis.Redis(**connection_kwargs)

            # Verify the connection BEFORE publishing the client globally:
            # previously the client was assigned first, so a failed ping both
            # leaked the client's connection pool and briefly exposed a
            # half-initialized client to other coroutines.
            await client.ping()
            _redis_client = client
            logger.info(f"Redis client connected successfully (TLS: {redis_tls_enabled})")
        except Exception as e:
            logger.warning(f"Failed to connect to Redis: {e}. Caching will be disabled.")
            # Close the client if it was constructed but failed the ping,
            # so its connection pool is not leaked.
            if client is not None:
                try:
                    await client.close()
                except Exception:
                    pass
            _redis_client = None

    return _redis_client
|
|
|
|
|
|
async def close_redis():
    """Tear down the shared Redis client, if one exists."""
    global _redis_client
    if _redis_client is None:
        return
    await _redis_client.close()
    _redis_client = None
    logger.info("Redis connection closed")
|
|
|
|
|
|
async def get_cached(key: str) -> Optional[Any]:
    """
    Fetch and JSON-decode a cached value.

    Args:
        key: Cache key to look up

    Returns:
        The deserialized value, or None on a miss, when Redis is
        unavailable, or on any Redis/JSON error
    """
    try:
        client = await get_redis_client()
        if client is None:
            return None

        raw = await client.get(key)
        if not raw:
            logger.debug(f"Cache miss: {key}")
            return None

        logger.debug(f"Cache hit: {key}")
        return json.loads(raw)
    except Exception as e:
        logger.warning(f"Cache get error for key {key}: {e}")
        return None
|
|
|
|
|
|
def _serialize_value(value: Any) -> Any:
    """
    Recursively convert a value into a JSON-serializable structure.

    Pydantic models are expanded with model_dump(); containers are walked
    recursively so nested models are handled too.

    Args:
        value: Value to serialize

    Returns:
        JSON-serializable value
    """
    if isinstance(value, BaseModel):
        # Pydantic v2: expand the model into a plain dictionary
        return value.model_dump()
    elif isinstance(value, (list, tuple, set, frozenset)):
        # Emit all sequence-like containers as lists. Sets/frozensets are
        # included because json.dumps cannot serialize them directly and
        # they previously fell through to the default branch unconverted.
        return [_serialize_value(item) for item in value]
    elif isinstance(value, dict):
        # Recursively serialize dictionary values (keys are left as-is)
        return {key: _serialize_value(val) for key, val in value.items()}
    else:
        # Scalars (str/int/float/bool/None) pass through; anything else is
        # left for json.dumps to accept or reject.
        return value
|
|
|
|
|
|
async def set_cached(key: str, value: Any, ttl: int = 60) -> bool:
    """
    Store a value in the cache under `key` with an expiry.

    Args:
        key: Cache key
        value: Value to cache; converted to JSON (Pydantic-aware) first
        ttl: Expiry in seconds

    Returns:
        True on success, False when Redis is unavailable or the write fails
    """
    try:
        client = await get_redis_client()
        if client is None:
            return False

        # Normalize Pydantic models / nested containers before encoding
        payload = json.dumps(_serialize_value(value))
        await client.setex(key, ttl, payload)
        logger.debug(f"Cache set: {key} (TTL: {ttl}s)")
        return True
    except Exception as e:
        logger.warning(f"Cache set error for key {key}: {e}")
        return False
|
|
|
|
|
|
async def delete_cached(key: str) -> bool:
    """
    Remove a single key from the cache.

    Args:
        key: Cache key

    Returns:
        True if the delete was issued, False when Redis is unavailable
        or the operation raised
    """
    try:
        client = await get_redis_client()
        if client is None:
            return False

        await client.delete(key)
        logger.debug(f"Cache deleted: {key}")
        return True
    except Exception as e:
        logger.warning(f"Cache delete error for key {key}: {e}")
        return False
|
|
|
|
|
|
async def delete_pattern(pattern: str) -> int:
    """
    Remove every cached key matching a glob-style pattern.

    Uses SCAN (not KEYS) so the server is not blocked on large keyspaces.

    Args:
        pattern: Redis key pattern (e.g., "dashboard:*")

    Returns:
        Number of keys deleted (0 on error or when nothing matched)
    """
    try:
        client = await get_redis_client()
        if client is None:
            return 0

        matched = [key async for key in client.scan_iter(match=pattern)]
        if not matched:
            return 0

        removed = await client.delete(*matched)
        logger.info(f"Deleted {removed} keys matching pattern: {pattern}")
        return removed
    except Exception as e:
        logger.warning(f"Cache delete pattern error for {pattern}: {e}")
        return 0
|
|
|
|
|
|
def cache_response(key_prefix: str, ttl: int = 60):
    """
    Decorator that caches an async endpoint's response per tenant.

    Args:
        key_prefix: Prefix for cache key (will be combined with tenant_id)
        ttl: Time to live in seconds

    Usage:
        @cache_response("dashboard:health", ttl=30)
        async def get_health(tenant_id: str):
            ...
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Locate the tenant id: prefer the keyword; otherwise fall back
            # to the first positional argument.
            tenant_id = kwargs.get('tenant_id')
            if not tenant_id and args:
                tenant_id = args[0]

            # Without a tenant there is no stable key — bypass caching.
            if not tenant_id:
                return await func(*args, **kwargs)

            cache_key = f"{key_prefix}:{tenant_id}"

            # Serve from cache when possible
            hit = await get_cached(cache_key)
            if hit is not None:
                return hit

            # Compute, then store the fresh result for subsequent calls
            result = await func(*args, **kwargs)
            await set_cached(cache_key, result, ttl)
            return result

        return wrapper
    return decorator
|
|
|
|
|
|
def make_cache_key(prefix: str, tenant_id: str, **params) -> str:
    """
    Build a deterministic cache key from a prefix, tenant, and parameters.

    Parameters are sorted by name so the same inputs always produce the
    same key; None-valued parameters are omitted.

    Args:
        prefix: Key prefix
        tenant_id: Tenant ID
        **params: Additional parameters to include in key

    Returns:
        Colon-joined cache key string
    """
    extras = [
        f"{name}:{value}"
        for name, value in sorted(params.items())
        if value is not None
    ]
    return ":".join([prefix, tenant_id] + extras)
|