Add role-based filtering and imporve code
This commit is contained in:
@@ -2,6 +2,6 @@
|
||||
|
||||
from .config import settings
|
||||
from .database import DatabaseManager, get_db
|
||||
from .redis_client import RedisClient, get_redis
|
||||
from .redis_wrapper import DemoRedisWrapper, get_redis
|
||||
|
||||
__all__ = ["settings", "DatabaseManager", "get_db", "RedisClient", "get_redis"]
|
||||
__all__ = ["settings", "DatabaseManager", "get_db", "DemoRedisWrapper", "get_redis"]
|
||||
|
||||
@@ -1,51 +1,25 @@
|
||||
"""
|
||||
Redis client for demo session data caching
|
||||
Redis wrapper for demo session service using shared Redis implementation
|
||||
Provides a compatibility layer for session-specific operations
|
||||
"""
|
||||
|
||||
import redis.asyncio as redis
|
||||
from typing import Optional, Any
|
||||
import json
|
||||
import structlog
|
||||
from datetime import timedelta
|
||||
|
||||
from .config import settings
|
||||
from typing import Optional, Any
|
||||
from shared.redis_utils import get_redis_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""Redis client for session data"""
|
||||
class DemoRedisWrapper:
|
||||
"""Wrapper around shared Redis client for demo session operations"""
|
||||
|
||||
def __init__(self, redis_url: str = None):
|
||||
self.redis_url = redis_url or settings.REDIS_URL
|
||||
self.client: Optional[redis.Redis] = None
|
||||
self.key_prefix = settings.REDIS_KEY_PREFIX
|
||||
def __init__(self, key_prefix: str = "demo_session"):
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to Redis"""
|
||||
if not self.client:
|
||||
self.client = await redis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
logger.info("Redis client connected", redis_url=self.redis_url.split("@")[-1])
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection"""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
logger.info("Redis connection closed")
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Check Redis connection"""
|
||||
try:
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
return await self.client.ping()
|
||||
except Exception as e:
|
||||
logger.error("Redis ping failed", error=str(e))
|
||||
return False
|
||||
async def get_client(self):
|
||||
"""Get the underlying Redis client"""
|
||||
return await get_redis_client()
|
||||
|
||||
def _make_key(self, *parts: str) -> str:
|
||||
"""Create Redis key with prefix"""
|
||||
@@ -53,26 +27,22 @@ class RedisClient:
|
||||
|
||||
async def set_session_data(self, session_id: str, key: str, data: Any, ttl: int = None):
|
||||
"""Store session data in Redis"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
redis_key = self._make_key(session_id, key)
|
||||
serialized = json.dumps(data) if not isinstance(data, str) else data
|
||||
|
||||
if ttl:
|
||||
await self.client.setex(redis_key, ttl, serialized)
|
||||
await client.setex(redis_key, ttl, serialized)
|
||||
else:
|
||||
await self.client.set(redis_key, serialized)
|
||||
await client.set(redis_key, serialized)
|
||||
|
||||
logger.debug("Session data stored", session_id=session_id, key=key)
|
||||
|
||||
async def get_session_data(self, session_id: str, key: str) -> Optional[Any]:
|
||||
"""Retrieve session data from Redis"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
redis_key = self._make_key(session_id, key)
|
||||
data = await self.client.get(redis_key)
|
||||
data = await client.get(redis_key)
|
||||
|
||||
if data:
|
||||
try:
|
||||
@@ -84,49 +54,42 @@ class RedisClient:
|
||||
|
||||
async def delete_session_data(self, session_id: str, key: str = None):
|
||||
"""Delete session data"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
client = await get_redis_client()
|
||||
|
||||
if key:
|
||||
redis_key = self._make_key(session_id, key)
|
||||
await self.client.delete(redis_key)
|
||||
await client.delete(redis_key)
|
||||
else:
|
||||
pattern = self._make_key(session_id, "*")
|
||||
keys = await self.client.keys(pattern)
|
||||
keys = await client.keys(pattern)
|
||||
if keys:
|
||||
await self.client.delete(*keys)
|
||||
await client.delete(*keys)
|
||||
|
||||
logger.debug("Session data deleted", session_id=session_id, key=key)
|
||||
|
||||
async def extend_session_ttl(self, session_id: str, ttl: int):
|
||||
"""Extend TTL for all session keys"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
pattern = self._make_key(session_id, "*")
|
||||
keys = await self.client.keys(pattern)
|
||||
keys = await client.keys(pattern)
|
||||
|
||||
for key in keys:
|
||||
await self.client.expire(key, ttl)
|
||||
await client.expire(key, ttl)
|
||||
|
||||
logger.debug("Session TTL extended", session_id=session_id, ttl=ttl)
|
||||
|
||||
async def set_hash(self, session_id: str, hash_key: str, field: str, value: Any):
|
||||
"""Store hash field in Redis"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
redis_key = self._make_key(session_id, hash_key)
|
||||
serialized = json.dumps(value) if not isinstance(value, str) else value
|
||||
await self.client.hset(redis_key, field, serialized)
|
||||
await client.hset(redis_key, field, serialized)
|
||||
|
||||
async def get_hash(self, session_id: str, hash_key: str, field: str) -> Optional[Any]:
|
||||
"""Get hash field from Redis"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
redis_key = self._make_key(session_id, hash_key)
|
||||
data = await self.client.hget(redis_key, field)
|
||||
data = await client.hget(redis_key, field)
|
||||
|
||||
if data:
|
||||
try:
|
||||
@@ -138,11 +101,9 @@ class RedisClient:
|
||||
|
||||
async def get_all_hash(self, session_id: str, hash_key: str) -> dict:
|
||||
"""Get all hash fields"""
|
||||
if not self.client:
|
||||
await self.connect()
|
||||
|
||||
client = await get_redis_client()
|
||||
redis_key = self._make_key(session_id, hash_key)
|
||||
data = await self.client.hgetall(redis_key)
|
||||
data = await client.hgetall(redis_key)
|
||||
|
||||
result = {}
|
||||
for field, value in data.items():
|
||||
@@ -153,12 +114,18 @@ class RedisClient:
|
||||
|
||||
return result
|
||||
|
||||
|
||||
redis_client = RedisClient()
|
||||
async def get_client(self):
|
||||
"""Get raw Redis client for direct operations"""
|
||||
return await get_redis_client()
|
||||
|
||||
|
||||
async def get_redis() -> RedisClient:
|
||||
"""Dependency for FastAPI"""
|
||||
if not redis_client.client:
|
||||
await redis_client.connect()
|
||||
return redis_client
|
||||
# Cached instance
|
||||
_redis_wrapper = None
|
||||
|
||||
|
||||
async def get_redis() -> DemoRedisWrapper:
|
||||
"""Dependency for FastAPI - returns wrapper around shared Redis"""
|
||||
global _redis_wrapper
|
||||
if _redis_wrapper is None:
|
||||
_redis_wrapper = DemoRedisWrapper()
|
||||
return _redis_wrapper
|
||||
Reference in New Issue
Block a user