Files
bakery-ia/shared/redis_utils/client.py
2025-10-15 16:12:49 +02:00

330 lines
8.3 KiB
Python

"""
Redis client initialization and connection management
Provides standardized Redis connection for all services
"""
import redis.asyncio as redis
from typing import Optional
import structlog
from contextlib import asynccontextmanager
# Module-level structlog logger; used by RedisConnectionManager (stored as
# self.logger) and by the convenience helpers defined in this module.
logger = structlog.get_logger()
class RedisConnectionManager:
    """
    Manages a pooled asyncio Redis client with logging and error handling.

    An instance owns at most one connection pool and one client at a time.
    The class itself neither enforces a singleton nor takes any locks; a
    single instance is shared through the module-level helper functions.
    """

    def __init__(self):
        # Both stay None until initialize() has verified connectivity.
        self._client: Optional[redis.Redis] = None
        self._pool: Optional[redis.ConnectionPool] = None
        self.logger = logger

    async def initialize(
        self,
        redis_url: str,
        db: int = 0,
        max_connections: int = 50,
        decode_responses: bool = True,
        retry_on_timeout: bool = True,
        socket_keepalive: bool = True,
        health_check_interval: int = 30
    ):
        """
        Create the connection pool and client, then verify them with PING.

        The pool and client are stored on the manager only after the ping
        succeeds. If anything fails, the freshly created pool is disconnected
        and the manager keeps whatever state it had before the call, so
        get_client() never returns a client that failed initialization.

        Calling initialize() again replaces the stored pool/client without
        closing the previous ones; call close() first if they must be released.

        Args:
            redis_url: Redis connection URL (redis://[:password]@host:port)
            db: Database number (0-15)
            max_connections: Maximum connections in pool
            decode_responses: Automatically decode responses to strings
            retry_on_timeout: Retry on timeout errors
            socket_keepalive: Enable TCP keepalive
            health_check_interval: Health check interval in seconds

        Raises:
            Exception: Any error from pool creation or the ping is logged
                and re-raised unchanged.
        """
        # Log only the part after the last "@" (host:port), not the password.
        safe_url = redis_url.split("@")[-1]

        pool = None
        try:
            # Create connection pool
            pool = redis.ConnectionPool.from_url(
                redis_url,
                db=db,
                max_connections=max_connections,
                decode_responses=decode_responses,
                retry_on_timeout=retry_on_timeout,
                socket_keepalive=socket_keepalive,
                health_check_interval=health_check_interval
            )
            # Create Redis client with pool
            client = redis.Redis(connection_pool=pool)
            # Test connection before publishing the client on the manager
            await client.ping()
        except Exception as e:
            self.logger.error(
                "redis_initialization_failed",
                error=str(e),
                redis_url=safe_url
            )
            # Do not leak the connections of a pool that never became usable.
            if pool is not None:
                try:
                    await pool.disconnect()
                except Exception as cleanup_error:
                    self.logger.warning(
                        "redis_pool_cleanup_failed",
                        error=str(cleanup_error)
                    )
            raise

        self._pool = pool
        self._client = client
        self.logger.info(
            "redis_initialized",
            redis_url=safe_url,
            db=db,
            max_connections=max_connections
        )

    async def close(self):
        """
        Close the Redis client and disconnect the pool.

        The stored references are cleared before closing, so that after
        close() get_client() raises RuntimeError and health_check() returns
        False instead of operating on closed resources. The pool is
        disconnected even if closing the client raises.
        """
        client, pool = self._client, self._pool
        self._client = None
        self._pool = None

        try:
            if client is not None:
                # redis-py 5.0.1+ deprecates close() on the asyncio client in
                # favour of aclose(); fall back to close() on older versions.
                close_client = getattr(client, "aclose", None) or client.close
                await close_client()
                self.logger.info("redis_client_closed")
        finally:
            if pool is not None:
                await pool.disconnect()
                self.logger.info("redis_pool_closed")

    def get_client(self) -> redis.Redis:
        """
        Get Redis client instance

        Returns:
            Redis client

        Raises:
            RuntimeError: If client not initialized
        """
        if self._client is None:
            raise RuntimeError("Redis client not initialized. Call initialize() first.")
        return self._client

    async def health_check(self) -> bool:
        """
        Check Redis connection health by sending PING.

        Returns:
            bool: True if the ping succeeds; False if the client is not
                initialized or the ping raises (the error is logged).
        """
        if self._client is None:
            return False
        try:
            await self._client.ping()
            return True
        except Exception as e:
            self.logger.error("redis_health_check_failed", error=str(e))
            return False

    async def get_info(self) -> dict:
        """
        Get Redis server information

        Returns:
            dict: Redis INFO command output, or an empty dict if the client
                is not initialized or the command fails (the error is logged).
        """
        if self._client is None:
            return {}
        try:
            return await self._client.info()
        except Exception as e:
            self.logger.error("redis_info_failed", error=str(e))
            return {}

    async def flush_db(self):
        """
        Flush current database (USE WITH CAUTION)
        Only for development/testing

        Raises:
            RuntimeError: If the client is not initialized.
            Exception: Any error from FLUSHDB is logged and re-raised.
        """
        try:
            if self._client is None:
                raise RuntimeError("Redis client not initialized")
            await self._client.flushdb()
            self.logger.warning("redis_database_flushed")
        except Exception as e:
            self.logger.error("redis_flush_failed", error=str(e))
            raise
# Global connection manager instance.
# Created lazily by get_redis_manager() and reset to None by close_redis().
_redis_manager: Optional[RedisConnectionManager] = None
async def get_redis_manager() -> RedisConnectionManager:
    """
    Return the module-wide Redis manager, creating it on first use.

    Returns:
        The shared RedisConnectionManager instance
    """
    global _redis_manager
    manager = _redis_manager
    if manager is None:
        # First call (or first call after close_redis()): build a fresh one.
        manager = RedisConnectionManager()
        _redis_manager = manager
    return manager
async def initialize_redis(
    redis_url: str,
    db: int = 0,
    max_connections: int = 50,
    **kwargs
) -> redis.Redis:
    """
    Initialize the shared Redis manager and hand back its client.

    Args:
        redis_url: Redis connection URL
        db: Database number
        max_connections: Maximum connections in pool
        **kwargs: Additional connection parameters, forwarded unchanged to
            RedisConnectionManager.initialize()

    Returns:
        Redis client instance
    """
    connection_options = {
        "redis_url": redis_url,
        "db": db,
        "max_connections": max_connections,
        **kwargs,
    }
    shared_manager = await get_redis_manager()
    await shared_manager.initialize(**connection_options)
    return shared_manager.get_client()
async def get_redis_client() -> redis.Redis:
    """
    Return the client held by the shared Redis manager.

    Returns:
        Redis client instance

    Raises:
        RuntimeError: If Redis not initialized
    """
    return (await get_redis_manager()).get_client()
async def close_redis():
    """Close the shared Redis manager, if one exists, and forget it."""
    global _redis_manager
    active_manager = _redis_manager
    if not active_manager:
        # Nothing was ever created (or it was already closed).
        return
    await active_manager.close()
    _redis_manager = None
@asynccontextmanager
async def redis_context(redis_url: str, db: int = 0):
    """
    Context manager for Redis connections

    Usage:
        async with redis_context(settings.REDIS_URL) as client:
            await client.set("key", "value")

    The connection is created through initialize_redis() and released through
    close_redis(), i.e. it goes through the module-wide manager: leaving the
    context closes that shared manager as well.

    Cleanup runs unconditionally. If initialization itself fails (for example
    the connectivity ping raises), the partially set-up manager is still
    closed instead of being left behind with an unusable pool.

    Args:
        redis_url: Redis connection URL
        db: Database number

    Yields:
        Redis client
    """
    try:
        client = await initialize_redis(redis_url, db=db)
        yield client
    finally:
        # close_redis() is a no-op when no manager exists, so it is safe to
        # call whether or not initialization got far enough to create one.
        await close_redis()
# Convenience functions for common operations
async def set_with_ttl(key: str, value: str, ttl: int) -> bool:
    """
    Store a value under a key with an expiry (SETEX).

    Args:
        key: Redis key
        value: Value to set
        ttl: Time to live in seconds

    Returns:
        bool: True if successful, False if any error occurred (it is logged)
    """
    try:
        redis_client = await get_redis_client()
        await redis_client.setex(key, ttl, value)
    except Exception as exc:
        logger.error("redis_set_failed", key=key, error=str(exc))
        return False
    else:
        return True
async def get_value(key: str) -> Optional[str]:
    """
    Fetch the value stored under a key.

    Args:
        key: Redis key

    Returns:
        Value, or None if not found or if the lookup failed (the error is
        logged)
    """
    stored = None
    try:
        redis_client = await get_redis_client()
        stored = await redis_client.get(key)
    except Exception as exc:
        logger.error("redis_get_failed", key=key, error=str(exc))
    return stored
async def increment_counter(key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
    """
    Increment counter with optional TTL

    When ``ttl`` is given, the expiry is applied only if the key has no
    expiry after the increment (Redis TTL reports -1). This covers the first
    increment of a new key and also repairs a counter whose expiry was never
    set (for example because an earlier EXPIRE call failed). It never
    restarts the window of a counter that is already expiring.

    Args:
        key: Redis key
        amount: Amount to increment
        ttl: Time to live in seconds (set when the key has no expiry yet)

    Returns:
        New counter value, or 0 if any Redis operation failed (the error is
        logged)
    """
    try:
        client = await get_redis_client()

        if not ttl:
            # No expiry requested: a plain INCRBY is all that is needed.
            return await client.incrby(key, amount)

        # Run INCRBY and TTL in one MULTI/EXEC so the reported TTL belongs to
        # exactly the state produced by this increment. Comparing the new
        # value with ``amount`` is not a reliable "key is new" test: it also
        # matches an existing counter that was at 0.
        async with client.pipeline(transaction=True) as pipe:
            pipe.incrby(key, amount)
            pipe.ttl(key)
            new_value, remaining_ttl = await pipe.execute()

        # TTL == -1 means the key exists but has no expiry.
        if remaining_ttl == -1:
            await client.expire(key, ttl)
        return new_value
    except Exception as e:
        logger.error("redis_increment_failed", key=key, error=str(e))
        return 0
async def get_keys_pattern(pattern: str) -> list:
    """
    Get keys matching pattern

    Uses incremental SCAN instead of KEYS. KEYS walks the whole keyspace in
    one blocking command, which can stall the server on large databases.
    SCAN may report a key more than once, so the result is de-duplicated
    (first-seen order is kept).

    Args:
        pattern: Redis key pattern (e.g., "quota:*")

    Returns:
        List of matching keys, or an empty list if the scan failed (the
        error is logged)
    """
    try:
        client = await get_redis_client()
        scanned = [key async for key in client.scan_iter(match=pattern)]
        # dict.fromkeys drops duplicates while preserving order.
        return list(dict.fromkeys(scanned))
    except Exception as e:
        logger.error("redis_keys_failed", pattern=pattern, error=str(e))
        return []