Files
bakery-ia/shared/redis_utils/client.py

529 lines
17 KiB
Python
Raw Normal View History

"""
Redis client initialization and connection management
Provides standardized Redis connection for all services
"""
2026-01-24 19:28:29 +01:00
import ssl
import redis.asyncio as redis
2026-01-24 19:28:29 +01:00
from typing import Optional, Dict, Any
import structlog
from contextlib import asynccontextmanager
logger = structlog.get_logger()
2026-01-24 19:28:29 +01:00
def get_ssl_kwargs_for_url(redis_url: str) -> Dict[str, Any]:
    """
    Return SSL keyword arguments suitable for ``redis.from_url()``.

    Only TLS URLs (``rediss://``) produce any options. For those, certificate
    verification is switched off so self-signed server certificates are
    accepted, and no CA bundle or client certificate/key is required.

    Args:
        redis_url: Redis connection URL (redis:// or rediss://). A falsy
            value (empty string / None) is treated as non-TLS.

    Returns:
        Dict of SSL kwargs; empty for non-TLS URLs.
    """
    uses_tls = bool(redis_url) and redis_url.startswith("rediss://")
    if not uses_tls:
        return {}

    return dict(
        ssl_cert_reqs=ssl.CERT_NONE,  # accept self-signed server certificates
        ssl_ca_certs=None,            # no CA bundle required
        ssl_certfile=None,            # no client certificate required
        ssl_keyfile=None,             # no client key required
    )
class RedisConnectionManager:
    """
    Manages Redis connections with connection pooling and error handling.

    Each instance owns one connection pool and one client bound to that pool.
    The module-level helpers in this file share a single instance per process.

    Usage:
        # Option 1: Using class method (recommended for new code)
        manager = await RedisConnectionManager.create(redis_url)
        client = manager.get_client()

        # Option 2: Using instance method
        manager = RedisConnectionManager()
        await manager.initialize(redis_url)
        client = manager.get_client()

        # Don't forget to close when done
        await manager.close()
    """

    def __init__(self):
        # All three are populated by initialize(); None until then.
        self._client: Optional[redis.Redis] = None
        self._pool: Optional[redis.ConnectionPool] = None
        self._redis_url: Optional[str] = None
        self.logger = logger

    @classmethod
    async def create(
        cls,
        redis_url: str,
        db: int = 0,
        max_connections: int = 50,
        decode_responses: bool = False,
        retry_on_timeout: bool = True,
        socket_keepalive: bool = True,
        health_check_interval: int = 30
    ) -> "RedisConnectionManager":
        """
        Factory method to create and initialize a RedisConnectionManager.

        This is the recommended way to create Redis connections across all
        services. TLS (rediss://) is configured automatically for
        self-signed certificates.

        Args:
            redis_url: Redis connection URL (redis:// or rediss://)
            db: Database number (0-15)
            max_connections: Maximum connections in pool
            decode_responses: Automatically decode responses to strings.
                NOTE: defaults to False here but to True in initialize();
                pass it explicitly if the difference matters.
            retry_on_timeout: Retry on timeout errors
            socket_keepalive: Enable TCP keepalive
            health_check_interval: Health check interval in seconds

        Returns:
            Initialized RedisConnectionManager

        Example:
            from shared.redis_utils import RedisConnectionManager
            manager = await RedisConnectionManager.create(settings.REDIS_URL)
            client = manager.get_client()
            await client.ping()
        """
        instance = cls()
        await instance.initialize(
            redis_url=redis_url,
            db=db,
            max_connections=max_connections,
            decode_responses=decode_responses,
            retry_on_timeout=retry_on_timeout,
            socket_keepalive=socket_keepalive,
            health_check_interval=health_check_interval,
        )
        return instance

    @staticmethod
    def _tls_connection_kwargs(redis_url: str) -> Dict[str, Any]:
        """
        Build the ssl_* connection kwargs for a rediss:// pool.

        Certificate and hostname verification are disabled so self-signed
        certificates work. Certificate files mounted in the container are used
        when present. Their paths come from REDIS_CA_CERTS_PATH,
        REDIS_CERTFILE_PATH and REDIS_KEYFILE_PATH, defaulting to /tls/*.pem.
        """
        import os

        ssl_kwargs = get_ssl_kwargs_for_url(redis_url)
        tls_kwargs: Dict[str, Any] = {
            "ssl_cert_reqs": ssl_kwargs.get("ssl_cert_reqs", ssl.CERT_NONE),
            # Hostname checks cannot succeed against self-signed certs.
            "ssl_check_hostname": False,
            # TLS 1.2 floor for internal cluster traffic.
            "ssl_min_version": ssl.TLSVersion.TLSv1_2,
        }

        ca_certs_path = os.getenv("REDIS_CA_CERTS_PATH", "/tls/ca-cert.pem")
        certfile_path = os.getenv("REDIS_CERTFILE_PATH", "/tls/redis-cert.pem")
        keyfile_path = os.getenv("REDIS_KEYFILE_PATH", "/tls/redis-key.pem")

        if os.path.exists(ca_certs_path):
            tls_kwargs["ssl_ca_certs"] = ca_certs_path
        # A client certificate is only loaded together with its key.
        if os.path.exists(certfile_path) and os.path.exists(keyfile_path):
            tls_kwargs["ssl_certfile"] = certfile_path
            tls_kwargs["ssl_keyfile"] = keyfile_path
        return tls_kwargs

    async def initialize(
        self,
        redis_url: str,
        db: int = 0,
        max_connections: int = 50,
        decode_responses: bool = True,
        retry_on_timeout: bool = True,
        socket_keepalive: bool = True,
        health_check_interval: int = 30
    ):
        """
        Initialize the Redis connection pool and client, then PING the server.

        Args:
            redis_url: Redis connection URL (redis://[:password]@host:port or
                rediss://... for TLS)
            db: Database number (0-15)
            max_connections: Maximum connections in pool
            decode_responses: Automatically decode responses to strings
            retry_on_timeout: Retry on timeout errors
            socket_keepalive: Enable TCP keepalive
            health_check_interval: Health check interval in seconds

        Raises:
            Any exception raised while building the pool or pinging the
            server; it is logged and re-raised.
        """
        try:
            self._redis_url = redis_url

            connection_kwargs: Dict[str, Any] = {
                'db': db,
                'max_connections': max_connections,
                'decode_responses': decode_responses,
                'retry_on_timeout': retry_on_timeout,
                'socket_keepalive': socket_keepalive,
                'health_check_interval': health_check_interval,
            }

            if redis_url.startswith("rediss://"):
                tls_kwargs = self._tls_connection_kwargs(redis_url)
                connection_kwargs.update(tls_kwargs)
                self.logger.debug(
                    "redis_ssl_config",
                    ssl_enabled=True,
                    ssl_cert_reqs=tls_kwargs.get('ssl_cert_reqs'),
                    ssl_ca_certs=tls_kwargs.get('ssl_ca_certs'),
                    ssl_certfile=tls_kwargs.get('ssl_certfile'),
                    ssl_keyfile=tls_kwargs.get('ssl_keyfile'),
                    ssl_check_hostname=False,
                    ssl_minimum_version="TLSv1_2"
                )

            # from_url() picks SSLConnection for rediss:// URLs and forwards the
            # ssl_* / decode_responses kwargs to every connection it opens.
            # These options must live on the pool: redis.Redis ignores its own
            # connection options when an explicit connection_pool is given.
            self._pool = redis.ConnectionPool.from_url(
                redis_url,
                **connection_kwargs
            )
            self._client = redis.Redis(connection_pool=self._pool)

            # Test connection
            await self._client.ping()

            self.logger.info(
                "redis_initialized",
                redis_url=redis_url.split("@")[-1],  # Log only host:port, not password
                db=db,
                max_connections=max_connections
            )
        except Exception as e:
            self.logger.error(
                "redis_initialization_failed",
                error=str(e),
                redis_url=redis_url.split("@")[-1]
            )
            raise

    async def close(self):
        """Close the Redis client, then disconnect every connection in the pool."""
        if self._client:
            await self._client.close()
            self.logger.info("redis_client_closed")
        # A client built on an explicit pool does not close that pool itself.
        if self._pool:
            await self._pool.disconnect()
            self.logger.info("redis_pool_closed")

    def get_client(self) -> redis.Redis:
        """
        Get Redis client instance

        Returns:
            Redis client

        Raises:
            RuntimeError: If client not initialized
        """
        if self._client is None:
            raise RuntimeError("Redis client not initialized. Call initialize() first.")
        return self._client

    async def health_check(self) -> bool:
        """
        Check Redis connection health with a PING.

        Returns:
            bool: True if healthy, False if uninitialized or the PING fails
        """
        try:
            if self._client is None:
                return False
            await self._client.ping()
            return True
        except Exception as e:
            self.logger.error("redis_health_check_failed", error=str(e))
            return False

    async def get_info(self) -> dict:
        """
        Get Redis server information

        Returns:
            dict: Redis INFO command output, or {} if uninitialized or on error
        """
        try:
            if self._client is None:
                return {}
            return await self._client.info()
        except Exception as e:
            self.logger.error("redis_info_failed", error=str(e))
            return {}

    async def flush_db(self):
        """
        Flush current database (USE WITH CAUTION)
        Only for development/testing

        Raises:
            RuntimeError: If client not initialized (logged, then re-raised)
        """
        try:
            if self._client is None:
                raise RuntimeError("Redis client not initialized")
            await self._client.flushdb()
            self.logger.warning("redis_database_flushed")
        except Exception as e:
            self.logger.error("redis_flush_failed", error=str(e))
            raise
# Global connection manager instance shared by the module-level helpers
_redis_manager: Optional[RedisConnectionManager] = None


async def get_redis_manager() -> RedisConnectionManager:
    """
    Return the process-wide RedisConnectionManager, creating it on first use.

    The manager is created uninitialized; initialize_redis() (or
    manager.initialize()) must run before its client is usable.

    Returns:
        RedisConnectionManager instance
    """
    global _redis_manager
    manager = _redis_manager
    if manager is None:
        manager = RedisConnectionManager()
        _redis_manager = manager
    return manager
async def initialize_redis(
    redis_url: str,
    db: int = 0,
    max_connections: int = 50,
    **kwargs
) -> redis.Redis:
    """
    Initialize the shared Redis manager and return its client.

    Args:
        redis_url: Redis connection URL
        db: Database number
        max_connections: Maximum connections in pool
        **kwargs: Extra options forwarded to RedisConnectionManager.initialize()

    Returns:
        Redis client instance
    """
    shared_manager = await get_redis_manager()
    init_options = dict(kwargs, db=db, max_connections=max_connections)
    await shared_manager.initialize(redis_url=redis_url, **init_options)
    return shared_manager.get_client()
async def get_redis_client() -> redis.Redis:
    """
    Return the client of the shared Redis manager.

    Returns:
        Redis client instance

    Raises:
        RuntimeError: If initialize_redis() has not been called yet
    """
    return (await get_redis_manager()).get_client()
async def close_redis():
    """Close the shared manager's connections and discard the manager."""
    global _redis_manager
    if not _redis_manager:
        return
    await _redis_manager.close()
    _redis_manager = None
@asynccontextmanager
async def redis_context(redis_url: str, db: int = 0):
    """
    Async context manager yielding an initialized Redis client.

    Usage:
        async with redis_context(settings.REDIS_URL) as client:
            await client.set("key", "value")

    NOTE: this (re)initializes the shared module-level manager and closes and
    discards it on exit. Other code using get_redis_client() is affected by
    the same lifecycle.

    Args:
        redis_url: Redis connection URL
        db: Database number

    Yields:
        Redis client
    """
    # If initialization raises, nothing was handed out, so nothing to close.
    client = await initialize_redis(redis_url, db=db)
    try:
        yield client
    finally:
        if client:
            await close_redis()
# Convenience functions for common operations
async def set_with_ttl(key: str, value: str, ttl: int) -> bool:
    """
    Store ``value`` under ``key`` with an expiry (SETEX) on the shared client.

    Args:
        key: Redis key
        value: Value to set
        ttl: Time to live in seconds

    Returns:
        bool: True on success, False if any error occurred (the error is logged)
    """
    try:
        client = await get_redis_client()
        await client.setex(key, ttl, value)
    except Exception as exc:
        logger.error("redis_set_failed", key=key, error=str(exc))
        return False
    return True
async def get_value(key: str) -> Optional[str]:
    """
    Fetch the value stored at ``key`` using the shared client.

    Args:
        key: Redis key

    Returns:
        The stored value, or None if the key is missing or an error occurred
        (errors are logged)
    """
    try:
        client = await get_redis_client()
        value = await client.get(key)
    except Exception as exc:
        logger.error("redis_get_failed", key=key, error=str(exc))
        value = None
    return value
async def increment_counter(key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
    """
    Increment a counter (INCRBY) and optionally give it an expiry.

    Args:
        key: Redis key
        amount: Amount to increment
        ttl: Time to live in seconds; applied only when the post-increment
            value equals ``amount`` (treated as "first increment")

    Returns:
        New counter value, or 0 if an error occurred (the error is logged)
    """
    try:
        client = await get_redis_client()
        counter = await client.incrby(key, amount)
        # Heuristic: result == amount means the key was absent before
        # (it also matches a pre-existing key that held 0).
        looks_new = counter == amount
        if ttl and looks_new:
            await client.expire(key, ttl)
        return counter
    except Exception as exc:
        logger.error("redis_increment_failed", key=key, error=str(exc))
        return 0
async def get_keys_pattern(pattern: str) -> list:
    """
    Get keys matching a glob-style pattern.

    Uses incremental SCAN (via ``scan_iter``) instead of KEYS. KEYS walks the
    whole keyspace in a single blocking command, whereas SCAN iterates in
    small cursor steps. SCAN may return the same key more than once, so
    results are de-duplicated, keeping first-seen order.

    Args:
        pattern: Redis key pattern (e.g., "quota:*")

    Returns:
        List of matching keys, or [] if an error occurred (the error is logged)
    """
    try:
        client = await get_redis_client()
        # dict keeps insertion order and drops SCAN duplicates.
        matches: dict = {}
        async for key in client.scan_iter(match=pattern):
            matches.setdefault(key, None)
        return list(matches)
    except Exception as e:
        logger.error("redis_keys_failed", pattern=pattern, error=str(e))
        return []