Add role-based filtering and imporve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -2,6 +2,6 @@
from .config import settings
from .database import DatabaseManager, get_db
from .redis_client import RedisClient, get_redis
from .redis_wrapper import DemoRedisWrapper, get_redis
__all__ = ["settings", "DatabaseManager", "get_db", "RedisClient", "get_redis"]
__all__ = ["settings", "DatabaseManager", "get_db", "DemoRedisWrapper", "get_redis"]

View File

@@ -1,51 +1,25 @@
"""
Redis client for demo session data caching
Redis wrapper for demo session service using shared Redis implementation
Provides a compatibility layer for session-specific operations
"""
import redis.asyncio as redis
from typing import Optional, Any
import json
import structlog
from datetime import timedelta
from .config import settings
from typing import Optional, Any
from shared.redis_utils import get_redis_client
logger = structlog.get_logger()
class RedisClient:
"""Redis client for session data"""
class DemoRedisWrapper:
"""Wrapper around shared Redis client for demo session operations"""
def __init__(self, redis_url: str = None):
self.redis_url = redis_url or settings.REDIS_URL
self.client: Optional[redis.Redis] = None
self.key_prefix = settings.REDIS_KEY_PREFIX
def __init__(self, key_prefix: str = "demo_session"):
self.key_prefix = key_prefix
async def connect(self):
"""Connect to Redis"""
if not self.client:
self.client = await redis.from_url(
self.redis_url,
encoding="utf-8",
decode_responses=True
)
logger.info("Redis client connected", redis_url=self.redis_url.split("@")[-1])
async def close(self):
"""Close Redis connection"""
if self.client:
await self.client.close()
logger.info("Redis connection closed")
async def ping(self) -> bool:
"""Check Redis connection"""
try:
if not self.client:
await self.connect()
return await self.client.ping()
except Exception as e:
logger.error("Redis ping failed", error=str(e))
return False
async def get_client(self):
"""Get the underlying Redis client"""
return await get_redis_client()
def _make_key(self, *parts: str) -> str:
"""Create Redis key with prefix"""
@@ -53,26 +27,22 @@ class RedisClient:
async def set_session_data(self, session_id: str, key: str, data: Any, ttl: int = None):
"""Store session data in Redis"""
if not self.client:
await self.connect()
client = await get_redis_client()
redis_key = self._make_key(session_id, key)
serialized = json.dumps(data) if not isinstance(data, str) else data
if ttl:
await self.client.setex(redis_key, ttl, serialized)
await client.setex(redis_key, ttl, serialized)
else:
await self.client.set(redis_key, serialized)
await client.set(redis_key, serialized)
logger.debug("Session data stored", session_id=session_id, key=key)
async def get_session_data(self, session_id: str, key: str) -> Optional[Any]:
"""Retrieve session data from Redis"""
if not self.client:
await self.connect()
client = await get_redis_client()
redis_key = self._make_key(session_id, key)
data = await self.client.get(redis_key)
data = await client.get(redis_key)
if data:
try:
@@ -84,49 +54,42 @@ class RedisClient:
async def delete_session_data(self, session_id: str, key: str = None):
"""Delete session data"""
if not self.client:
await self.connect()
client = await get_redis_client()
if key:
redis_key = self._make_key(session_id, key)
await self.client.delete(redis_key)
await client.delete(redis_key)
else:
pattern = self._make_key(session_id, "*")
keys = await self.client.keys(pattern)
keys = await client.keys(pattern)
if keys:
await self.client.delete(*keys)
await client.delete(*keys)
logger.debug("Session data deleted", session_id=session_id, key=key)
async def extend_session_ttl(self, session_id: str, ttl: int):
"""Extend TTL for all session keys"""
if not self.client:
await self.connect()
client = await get_redis_client()
pattern = self._make_key(session_id, "*")
keys = await self.client.keys(pattern)
keys = await client.keys(pattern)
for key in keys:
await self.client.expire(key, ttl)
await client.expire(key, ttl)
logger.debug("Session TTL extended", session_id=session_id, ttl=ttl)
async def set_hash(self, session_id: str, hash_key: str, field: str, value: Any):
"""Store hash field in Redis"""
if not self.client:
await self.connect()
client = await get_redis_client()
redis_key = self._make_key(session_id, hash_key)
serialized = json.dumps(value) if not isinstance(value, str) else value
await self.client.hset(redis_key, field, serialized)
await client.hset(redis_key, field, serialized)
async def get_hash(self, session_id: str, hash_key: str, field: str) -> Optional[Any]:
"""Get hash field from Redis"""
if not self.client:
await self.connect()
client = await get_redis_client()
redis_key = self._make_key(session_id, hash_key)
data = await self.client.hget(redis_key, field)
data = await client.hget(redis_key, field)
if data:
try:
@@ -138,11 +101,9 @@ class RedisClient:
async def get_all_hash(self, session_id: str, hash_key: str) -> dict:
"""Get all hash fields"""
if not self.client:
await self.connect()
client = await get_redis_client()
redis_key = self._make_key(session_id, hash_key)
data = await self.client.hgetall(redis_key)
data = await client.hgetall(redis_key)
result = {}
for field, value in data.items():
@@ -153,12 +114,18 @@ class RedisClient:
return result
redis_client = RedisClient()
async def get_client(self):
"""Get raw Redis client for direct operations"""
return await get_redis_client()
async def get_redis() -> RedisClient:
"""Dependency for FastAPI"""
if not redis_client.client:
await redis_client.connect()
return redis_client
# Cached instance
_redis_wrapper = None
async def get_redis() -> DemoRedisWrapper:
"""Dependency for FastAPI - returns wrapper around shared Redis"""
global _redis_wrapper
if _redis_wrapper is None:
_redis_wrapper = DemoRedisWrapper()
return _redis_wrapper