Files
bakery-ia/shared/redis_utils/client.py

344 lines
9.0 KiB
Python
Raw Normal View History

2026-01-21 17:17:16 +01:00
"""
Redis client initialization and connection management
Provides standardized Redis connection for all services
"""
import redis.asyncio as redis
from typing import Optional
import structlog
from contextlib import asynccontextmanager
# Module-level structlog logger; RedisConnectionManager stores it as self.logger and
# the convenience functions at the bottom of this module log through it directly.
logger = structlog.get_logger()
class RedisConnectionManager:
    """
    Manage a pooled asyncio Redis client with logging and error handling.

    An instance owns at most one ``redis.ConnectionPool`` and one ``redis.Redis``
    client bound to that pool. The class does not enforce a singleton itself;
    the module-level ``get_redis_manager()`` helper shares one instance per
    process. No locks are used, so concurrent ``initialize()``/``close()`` calls
    on the same instance are not coordinated.
    """

    def __init__(self):
        # Both stay None until initialize() succeeds in assigning them;
        # close() resets them to None again.
        self._client: Optional[redis.Redis] = None
        self._pool: Optional[redis.ConnectionPool] = None
        self.logger = logger

    async def initialize(
        self,
        redis_url: str,
        db: int = 0,
        max_connections: int = 50,
        decode_responses: bool = True,
        retry_on_timeout: bool = True,
        socket_keepalive: bool = True,
        health_check_interval: int = 30
    ):
        """
        Create the connection pool and client, then verify them with PING.

        Args:
            redis_url: Redis connection URL (redis://[:password]@host:port);
                a ``rediss://`` URL enables TLS.
            db: Database number (0-15)
            max_connections: Maximum connections in pool
            decode_responses: Automatically decode responses to strings
            retry_on_timeout: Retry on timeout errors
            socket_keepalive: Enable TCP keepalive
            health_check_interval: Health check interval in seconds

        Raises:
            Exception: Whatever pool creation or the initial PING raises. It is
                logged as ``redis_initialization_failed`` and re-raised. If the
                PING is what failed, the new client and pool remain assigned.

        NOTE(review): calling initialize() again replaces the previous client
        and pool without closing them.
        """
        # Everything after the last "@" is host:port[/db]; logging only that
        # part keeps credentials embedded in the URL out of the logs.
        safe_url = redis_url.split("@")[-1]
        try:
            connection_kwargs = {
                'db': db,
                'max_connections': max_connections,
                'decode_responses': decode_responses,
                'retry_on_timeout': retry_on_timeout,
                'socket_keepalive': socket_keepalive,
                'health_check_interval': health_check_interval
            }
            if redis_url.startswith('rediss://'):
                # NOTE(review): this disables TLS certificate verification so
                # self-signed server certificates are accepted. The connection
                # is encrypted but the server is not authenticated, which leaves
                # it open to man-in-the-middle attacks. Consider making
                # verification configurable for production deployments.
                connection_kwargs.update({
                    'ssl_cert_reqs': None,  # Disable certificate verification
                    'ssl_ca_certs': None,   # Don't require CA certificates
                    'ssl_certfile': None,   # Don't require client cert
                    'ssl_keyfile': None     # Don't require client key
                })
            self._pool = redis.ConnectionPool.from_url(
                redis_url,
                **connection_kwargs
            )
            # The client shares the pool; it does not own a private one.
            self._client = redis.Redis(connection_pool=self._pool)
            # Fail fast if the server is unreachable or credentials are wrong.
            await self._client.ping()
            self.logger.info(
                "redis_initialized",
                redis_url=safe_url,
                db=db,
                max_connections=max_connections
            )
        except Exception as e:
            self.logger.error(
                "redis_initialization_failed",
                error=str(e),
                redis_url=safe_url
            )
            raise

    async def close(self):
        """
        Close the Redis client and disconnect its connection pool.

        The manager's references are cleared first. As a result, ``get_client()``
        raises and ``health_check()`` returns False until ``initialize()`` is
        called again. The pool is disconnected even if closing the client
        raises. Calling ``close()`` on an uninitialized or already-closed
        manager does nothing.
        """
        client, pool = self._client, self._pool
        self._client = None
        self._pool = None
        try:
            if client is not None:
                # redis-py >= 5.0.1 deprecates Redis.close() in favour of
                # aclose(); fall back to close() on older versions.
                closer = getattr(client, "aclose", None) or client.close
                await closer()
                self.logger.info("redis_client_closed")
        finally:
            if pool is not None:
                await pool.disconnect()
                self.logger.info("redis_pool_closed")

    def get_client(self) -> redis.Redis:
        """
        Get the Redis client instance.

        Returns:
            Redis client

        Raises:
            RuntimeError: If the client is not initialized (or was closed).
        """
        if self._client is None:
            raise RuntimeError("Redis client not initialized. Call initialize() first.")
        return self._client

    async def health_check(self) -> bool:
        """
        Check Redis connection health with a PING.

        Returns:
            bool: True if the PING succeeds. False if the client is not
            initialized or the PING raises; the error is logged.
        """
        try:
            if self._client is None:
                return False
            await self._client.ping()
            return True
        except Exception as e:
            self.logger.error("redis_health_check_failed", error=str(e))
            return False

    async def get_info(self) -> dict:
        """
        Get Redis server information.

        Returns:
            dict: Output of the INFO command. An empty dict if the client is not
            initialized or the command fails; failures are logged.
        """
        try:
            if self._client is None:
                return {}
            return await self._client.info()
        except Exception as e:
            self.logger.error("redis_info_failed", error=str(e))
            return {}

    async def flush_db(self):
        """
        Flush the current database (USE WITH CAUTION).

        Intended only for development/testing. This is not enforced here, so
        callers must guard against using it in production.

        Raises:
            RuntimeError: If the client is not initialized.
            Exception: Any error from FLUSHDB; it is logged and re-raised.
        """
        try:
            if self._client is None:
                raise RuntimeError("Redis client not initialized")
            await self._client.flushdb()
            self.logger.warning("redis_database_flushed")
        except Exception as e:
            self.logger.error("redis_flush_failed", error=str(e))
            raise
# Global connection manager instance.
# Created lazily by get_redis_manager() and reset to None by close_redis();
# None means no manager exists yet (or the previous one was closed).
_redis_manager: Optional[RedisConnectionManager] = None
async def get_redis_manager() -> RedisConnectionManager:
    """
    Return the process-wide RedisConnectionManager, creating it on first use.

    A newly created manager is NOT connected yet; ``initialize()`` (or
    ``initialize_redis()``) must run before its client can be used.

    Returns:
        RedisConnectionManager instance
    """
    global _redis_manager
    if _redis_manager is not None:
        return _redis_manager
    _redis_manager = RedisConnectionManager()
    return _redis_manager
async def initialize_redis(
    redis_url: str,
    db: int = 0,
    max_connections: int = 50,
    **kwargs
) -> redis.Redis:
    """
    Initialize the process-wide Redis manager and hand back its client.

    Args:
        redis_url: Redis connection URL
        db: Database number
        max_connections: Maximum connections in pool
        **kwargs: Extra options forwarded to ``RedisConnectionManager.initialize``

    Returns:
        Redis client instance

    Raises:
        Whatever ``RedisConnectionManager.initialize`` raises.
    """
    manager = await get_redis_manager()
    # Named parameters can never collide with **kwargs, so merging is safe.
    init_options = dict(kwargs, db=db, max_connections=max_connections)
    await manager.initialize(redis_url, **init_options)
    return manager.get_client()
async def get_redis_client() -> redis.Redis:
    """
    Return the client held by the process-wide Redis manager.

    Returns:
        Redis client instance

    Raises:
        RuntimeError: If ``initialize_redis()`` has not been called (or the
            manager has since been closed).
    """
    return (await get_redis_manager()).get_client()
async def close_redis():
    """
    Close the process-wide Redis manager and forget it.

    The global manager is detached *before* its ``close()`` is awaited. This
    guarantees the global is cleared even if closing raises. It also means code
    running while the close is in progress gets a fresh, uninitialized manager
    rather than the one being torn down. Any exception from ``close()`` still
    propagates. Does nothing when no manager exists.
    """
    global _redis_manager
    manager, _redis_manager = _redis_manager, None
    if manager is not None:
        await manager.close()
@asynccontextmanager
async def redis_context(redis_url: str, db: int = 0):
    """
    Async context manager yielding an initialized Redis client.

    Usage:
        async with redis_context(settings.REDIS_URL) as client:
            await client.set("key", "value")

    NOTE: this does not open a private connection. It (re)initializes the
    module-wide manager via ``initialize_redis()`` and calls ``close_redis()``
    on exit. After the block exits, ``get_redis_client()`` raises until
    ``initialize_redis()`` is called again, so avoid using it inside code that
    relies on the global connection staying open. If initialization fails,
    nothing is closed and the error propagates.

    Args:
        redis_url: Redis connection URL
        db: Database number

    Yields:
        Redis client
    """
    client = await initialize_redis(redis_url, db=db)
    try:
        yield client
    finally:
        await close_redis()
# Convenience functions for common operations
async def set_with_ttl(key: str, value: str, ttl: int) -> bool:
    """
    Store ``value`` under ``key`` with an expiry of ``ttl`` seconds (SETEX).

    Errors (including an uninitialized client) are logged as
    ``redis_set_failed`` and reported through the return value, not raised.

    Args:
        key: Redis key
        value: Value to set
        ttl: Time to live in seconds

    Returns:
        bool: True when the write succeeded, False on any error
    """
    try:
        redis_client = await get_redis_client()
        await redis_client.setex(key, ttl, value)
    except Exception as exc:
        logger.error("redis_set_failed", key=key, error=str(exc))
        return False
    else:
        return True
async def get_value(key: str) -> Optional[str]:
    """
    Fetch the value stored under ``key``.

    Errors are logged as ``redis_get_failed`` and also yield None. Callers
    therefore cannot tell a missing key from a failed lookup.

    Args:
        key: Redis key

    Returns:
        The stored value, or None if the key is absent or an error occurred
    """
    try:
        redis_client = await get_redis_client()
        result = await redis_client.get(key)
    except Exception as exc:
        logger.error("redis_get_failed", key=key, error=str(exc))
        result = None
    return result
async def increment_counter(key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
    """
    Increment a counter and make sure it carries an expiry.

    INCRBY and TTL are sent together in one MULTI/EXEC pipeline, so a single
    round trip returns both the new value and the key's remaining TTL. When
    ``ttl`` is truthy and the key has no expiry (TTL == -1), EXPIRE is issued;
    that is always the case right after the key is created.

    The previous version only tried EXPIRE when ``new_value == amount``. That
    check also repairs nothing: if the EXPIRE after the first increment failed
    or the task was cancelled between the two commands, the counter never
    expired. This version fixes such keys on the next call.

    Args:
        key: Redis key
        amount: Amount to increment
        ttl: Time to live in seconds, applied whenever the key has no expiry
            (normally the first increment). Falsy values (None, 0) skip it.

    Returns:
        New counter value, or 0 if any error occurred (the error is logged)
    """
    try:
        client = await get_redis_client()
        async with client.pipeline(transaction=True) as pipe:
            pipe.incrby(key, amount)
            pipe.ttl(key)
            new_value, current_ttl = await pipe.execute()
        # TTL replies -1 for a key that exists but has no expiry set.
        if ttl and current_ttl == -1:
            await client.expire(key, ttl)
        return new_value
    except Exception as e:
        logger.error("redis_increment_failed", key=key, error=str(e))
        return 0
async def get_keys_pattern(pattern: str, *, count: int = 1000) -> list:
    """
    Get keys matching a glob-style pattern without blocking the server.

    Uses incremental SCAN (via ``scan_iter``) instead of KEYS. KEYS walks the
    entire keyspace in one blocking call and stalls every other client on
    large databases. SCAN may return the same key more than once, so results
    are de-duplicated while keeping first-seen order. As with KEYS, the order
    of the returned keys carries no meaning. Keys created or deleted while the
    scan runs may or may not be included.

    Args:
        pattern: Redis key pattern (e.g., "quota:*")
        count: COUNT hint per SCAN call (roughly how many keys are examined per
            round trip); keyword-only

    Returns:
        List of matching keys, or [] if an error occurred (the error is logged)
    """
    try:
        client = await get_redis_client()
        matches = [key async for key in client.scan_iter(match=pattern, count=count)]
        return list(dict.fromkeys(matches))
    except Exception as e:
        logger.error("redis_keys_failed", pattern=pattern, error=str(e))
        return []