Add role-based filtering and improve code

Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -8,7 +8,7 @@ import jwt
 from app.api.schemas import DemoSessionResponse, DemoSessionStats
 from app.services import DemoSessionManager, DemoCleanupService
-from app.core import get_db, get_redis, RedisClient
+from app.core import get_db, get_redis, DemoRedisWrapper
 from sqlalchemy.ext.asyncio import AsyncSession
 from shared.routing import RouteBuilder
@@ -25,7 +25,7 @@ route_builder = RouteBuilder('demo')
 async def extend_demo_session(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Extend demo session expiration (BUSINESS OPERATION)"""
     try:
@@ -67,7 +67,7 @@ async def extend_demo_session(
 )
 async def get_demo_stats(
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Get demo session statistics (BUSINESS OPERATION)"""
     session_manager = DemoSessionManager(db, redis)
@@ -81,7 +81,7 @@ async def get_demo_stats(
 )
 async def run_cleanup(
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Manually trigger session cleanup (BUSINESS OPERATION - Internal endpoint for CronJob)"""
     cleanup_service = DemoCleanupService(db, redis)
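In the routers the change is confined to the dependency annotation: `DemoRedisWrapper` keeps the old `RedisClient` call surface, so the handler bodies stay as they were. A minimal sketch of how the injected wrapper is consumed — the endpoint path and response shape are illustrative only; the import and the `Depends(get_redis)` wiring come from this diff:

```python
from fastapi import APIRouter, Depends, Path

from app.core import get_redis, DemoRedisWrapper  # re-exported by app.core, as shown below

router = APIRouter()

@router.get("/sessions/{session_id}/cache-probe")  # hypothetical path, for illustration
async def cache_probe(
    session_id: str = Path(...),
    redis: DemoRedisWrapper = Depends(get_redis),  # injected exactly like the old RedisClient
):
    # The wrapper keeps the pre-refactor method signatures, so existing call
    # sites such as DemoSessionManager(db, redis) continue to work unchanged.
    cached_status = await redis.get_session_data(session_id, "status")
    return {"session_id": session_id, "cached_status": cached_status}
```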

View File

@@ -10,7 +10,8 @@ import jwt
 from app.api.schemas import DemoSessionCreate, DemoSessionResponse
 from app.services import DemoSessionManager
-from app.core import get_db, get_redis, RedisClient
+from app.core import get_db
+from app.core.redis_wrapper import get_redis, DemoRedisWrapper
 from sqlalchemy.ext.asyncio import AsyncSession
 from shared.routing import RouteBuilder
@@ -64,7 +65,7 @@ async def create_demo_session(
     request: DemoSessionCreate,
     http_request: Request,
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Create a new isolated demo session (ATOMIC)"""
     logger.info("Creating demo session", demo_account_type=request.demo_account_type)
@@ -130,7 +131,7 @@ async def create_demo_session(
 async def get_session_info(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Get demo session information (ATOMIC READ)"""
     session_manager = DemoSessionManager(db, redis)
@@ -149,7 +150,7 @@ async def get_session_info(
 async def get_session_status(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """
     Get demo session provisioning status
@@ -173,7 +174,7 @@ async def get_session_status(
 async def retry_session_cloning(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """
     Retry failed cloning operations
@@ -204,7 +205,7 @@ async def retry_session_cloning(
 async def destroy_demo_session(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Destroy demo session and cleanup resources (ATOMIC DELETE)"""
     try:
@@ -225,7 +226,7 @@ async def destroy_demo_session(
 async def destroy_demo_session_post(
     session_id: str = Path(...),
     db: AsyncSession = Depends(get_db),
-    redis: RedisClient = Depends(get_redis)
+    redis: DemoRedisWrapper = Depends(get_redis)
 ):
     """Destroy demo session via POST (for frontend compatibility)"""
     try:

View File

@@ -2,6 +2,6 @@
 from .config import settings
 from .database import DatabaseManager, get_db
-from .redis_client import RedisClient, get_redis
+from .redis_wrapper import DemoRedisWrapper, get_redis
-__all__ = ["settings", "DatabaseManager", "get_db", "RedisClient", "get_redis"]
+__all__ = ["settings", "DatabaseManager", "get_db", "DemoRedisWrapper", "get_redis"]

View File

@@ -1,51 +1,25 @@
 """
-Redis client for demo session data caching
+Redis wrapper for demo session service using shared Redis implementation
+Provides a compatibility layer for session-specific operations
 """
-import redis.asyncio as redis
-from typing import Optional, Any
 import json
 import structlog
-from datetime import timedelta
-from .config import settings
+from typing import Optional, Any
+from shared.redis_utils import get_redis_client
 logger = structlog.get_logger()
-class RedisClient:
-    """Redis client for session data"""
+class DemoRedisWrapper:
+    """Wrapper around shared Redis client for demo session operations"""
-    def __init__(self, redis_url: str = None):
-        self.redis_url = redis_url or settings.REDIS_URL
-        self.client: Optional[redis.Redis] = None
-        self.key_prefix = settings.REDIS_KEY_PREFIX
+    def __init__(self, key_prefix: str = "demo_session"):
+        self.key_prefix = key_prefix
-    async def connect(self):
-        """Connect to Redis"""
-        if not self.client:
-            self.client = await redis.from_url(
-                self.redis_url,
-                encoding="utf-8",
-                decode_responses=True
-            )
-            logger.info("Redis client connected", redis_url=self.redis_url.split("@")[-1])
-    async def close(self):
-        """Close Redis connection"""
-        if self.client:
-            await self.client.close()
-            logger.info("Redis connection closed")
-    async def ping(self) -> bool:
-        """Check Redis connection"""
-        try:
-            if not self.client:
-                await self.connect()
-            return await self.client.ping()
-        except Exception as e:
-            logger.error("Redis ping failed", error=str(e))
-            return False
+    async def get_client(self):
+        """Get the underlying Redis client"""
+        return await get_redis_client()
     def _make_key(self, *parts: str) -> str:
         """Create Redis key with prefix"""
@@ -53,26 +27,22 @@ class RedisClient:
     async def set_session_data(self, session_id: str, key: str, data: Any, ttl: int = None):
         """Store session data in Redis"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         redis_key = self._make_key(session_id, key)
         serialized = json.dumps(data) if not isinstance(data, str) else data
         if ttl:
-            await self.client.setex(redis_key, ttl, serialized)
+            await client.setex(redis_key, ttl, serialized)
         else:
-            await self.client.set(redis_key, serialized)
+            await client.set(redis_key, serialized)
         logger.debug("Session data stored", session_id=session_id, key=key)
     async def get_session_data(self, session_id: str, key: str) -> Optional[Any]:
         """Retrieve session data from Redis"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         redis_key = self._make_key(session_id, key)
-        data = await self.client.get(redis_key)
+        data = await client.get(redis_key)
         if data:
             try:
@@ -84,49 +54,42 @@ class RedisClient:
     async def delete_session_data(self, session_id: str, key: str = None):
         """Delete session data"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         if key:
             redis_key = self._make_key(session_id, key)
-            await self.client.delete(redis_key)
+            await client.delete(redis_key)
         else:
             pattern = self._make_key(session_id, "*")
-            keys = await self.client.keys(pattern)
+            keys = await client.keys(pattern)
             if keys:
-                await self.client.delete(*keys)
+                await client.delete(*keys)
         logger.debug("Session data deleted", session_id=session_id, key=key)
     async def extend_session_ttl(self, session_id: str, ttl: int):
         """Extend TTL for all session keys"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         pattern = self._make_key(session_id, "*")
-        keys = await self.client.keys(pattern)
+        keys = await client.keys(pattern)
         for key in keys:
-            await self.client.expire(key, ttl)
+            await client.expire(key, ttl)
         logger.debug("Session TTL extended", session_id=session_id, ttl=ttl)
     async def set_hash(self, session_id: str, hash_key: str, field: str, value: Any):
         """Store hash field in Redis"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         redis_key = self._make_key(session_id, hash_key)
         serialized = json.dumps(value) if not isinstance(value, str) else value
-        await self.client.hset(redis_key, field, serialized)
+        await client.hset(redis_key, field, serialized)
     async def get_hash(self, session_id: str, hash_key: str, field: str) -> Optional[Any]:
         """Get hash field from Redis"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         redis_key = self._make_key(session_id, hash_key)
-        data = await self.client.hget(redis_key, field)
+        data = await client.hget(redis_key, field)
         if data:
             try:
@@ -138,11 +101,9 @@ class RedisClient:
     async def get_all_hash(self, session_id: str, hash_key: str) -> dict:
         """Get all hash fields"""
-        if not self.client:
-            await self.connect()
+        client = await get_redis_client()
         redis_key = self._make_key(session_id, hash_key)
-        data = await self.client.hgetall(redis_key)
+        data = await client.hgetall(redis_key)
         result = {}
         for field, value in data.items():
@@ -153,12 +114,18 @@ class RedisClient:
         return result
-redis_client = RedisClient()
+    async def get_client(self):
+        """Get raw Redis client for direct operations"""
+        return await get_redis_client()
-async def get_redis() -> RedisClient:
-    """Dependency for FastAPI"""
-    if not redis_client.client:
-        await redis_client.connect()
-    return redis_client
+# Cached instance
+_redis_wrapper = None
+async def get_redis() -> DemoRedisWrapper:
+    """Dependency for FastAPI - returns wrapper around shared Redis"""
+    global _redis_wrapper
+    if _redis_wrapper is None:
+        _redis_wrapper = DemoRedisWrapper()
+    return _redis_wrapper
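The wrapper keeps the per-session API of the old `RedisClient` but borrows a connection from the shared pool on every call, so there is no service-local connect/close lifecycle left. A minimal round-trip sketch, assuming a reachable Redis at `redis://localhost:6379` and that the shared pool is initialized first (in the service this happens in the lifespan, shown below):

```python
import asyncio

from shared.redis_utils import initialize_redis, close_redis
from app.core.redis_wrapper import get_redis

async def round_trip() -> None:
    # The shared pool is normally initialized in the service lifespan; doing it
    # here keeps the sketch self-contained (assumes Redis on localhost:6379).
    await initialize_redis(redis_url="redis://localhost:6379", db=0, max_connections=50)

    redis = await get_redis()  # cached DemoRedisWrapper instance
    await redis.set_session_data("sess-123", "profile", {"plan": "demo"}, ttl=3600)
    profile = await redis.get_session_data("sess-123", "profile")  # JSON round-trips back to a dict
    print(profile)

    await redis.delete_session_data("sess-123")  # no key argument: drops every key for the session
    await close_redis()

if __name__ == "__main__":
    asyncio.run(round_trip())
```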

View File

@@ -9,14 +9,14 @@ from fastapi.responses import JSONResponse
 import structlog
 from contextlib import asynccontextmanager
-from app.core import settings, DatabaseManager, RedisClient
+from app.core import settings, DatabaseManager
 from app.api import demo_sessions, demo_accounts, demo_operations
+from shared.redis_utils import initialize_redis, close_redis
 logger = structlog.get_logger()
-# Initialize database and redis
+# Initialize database
 db_manager = DatabaseManager()
-redis_client = RedisClient()
 @asynccontextmanager
@@ -27,8 +27,12 @@ async def lifespan(app: FastAPI):
     # Initialize database
     db_manager.initialize()
-    # Connect to Redis
-    await redis_client.connect()
+    # Initialize Redis using shared implementation
+    await initialize_redis(
+        redis_url=settings.REDIS_URL,
+        db=0,
+        max_connections=50
+    )
     logger.info("Demo Session Service started successfully")
@@ -36,7 +40,7 @@ async def lifespan(app: FastAPI):
     # Cleanup on shutdown
     await db_manager.close()
-    await redis_client.close()
+    await close_redis()
     logger.info("Demo Session Service stopped")
@@ -92,7 +96,10 @@ async def root():
 @app.get("/health")
 async def health():
     """Health check endpoint"""
-    redis_ok = await redis_client.ping()
+    from shared.redis_utils import get_redis_manager
+    redis_manager = await get_redis_manager()
+    redis_ok = await redis_manager.health_check()
     return {
         "status": "healthy" if redis_ok else "degraded",

View File

@@ -1,5 +1,12 @@
+# Import AuditLog model for this service
+from shared.security import create_audit_log_model
+from shared.database.base import Base
+# Create audit log model for this service
+AuditLog = create_audit_log_model(Base)
 """Demo Session Service Models"""
 from .demo_session import DemoSession, DemoSessionStatus, CloningStatus
-__all__ = ["DemoSession", "DemoSessionStatus", "CloningStatus"]
+__all__ = ["DemoSession", "DemoSessionStatus", "CloningStatus", "AuditLog"]

View File

@@ -11,7 +11,7 @@ import structlog
 from app.models import DemoSession, DemoSessionStatus
 from app.services.data_cloner import DemoDataCloner
-from app.core import RedisClient
+from app.core.redis_wrapper import DemoRedisWrapper
 logger = structlog.get_logger()
@@ -19,7 +19,7 @@ logger = structlog.get_logger()
 class DemoCleanupService:
     """Handles cleanup of expired demo sessions"""
-    def __init__(self, db: AsyncSession, redis: RedisClient):
+    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
         self.db = db
         self.redis = redis
         self.data_cloner = DemoDataCloner(db, redis)

View File

@@ -9,7 +9,8 @@ import httpx
 import structlog
 import uuid
-from app.core import RedisClient, settings
+from app.core.redis_wrapper import DemoRedisWrapper
+from app.core import settings
 logger = structlog.get_logger()
@@ -17,7 +18,7 @@ logger = structlog.get_logger()
 class DemoDataCloner:
     """Clones demo data for isolated sessions"""
-    def __init__(self, db: AsyncSession, redis: RedisClient):
+    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
         self.db = db
         self.redis = redis

View File

@@ -12,7 +12,8 @@ import secrets
 import structlog
 from app.models import DemoSession, DemoSessionStatus, CloningStatus
-from app.core import RedisClient, settings
+from app.core.redis_wrapper import DemoRedisWrapper
+from app.core import settings
 from app.services.clone_orchestrator import CloneOrchestrator
 logger = structlog.get_logger()
@@ -21,7 +22,7 @@ logger = structlog.get_logger()
 class DemoSessionManager:
     """Manages demo session lifecycle"""
-    def __init__(self, db: AsyncSession, redis: RedisClient):
+    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
         self.db = db
         self.redis = redis
         self.orchestrator = CloneOrchestrator()
@@ -367,7 +368,8 @@ class DemoSessionManager:
         }
         import json as json_module
-        await self.redis.client.setex(
+        client = await self.redis.get_client()
+        await client.setex(
             status_key,
             7200, # Cache for 2 hours
             json_module.dumps(status_data) # Convert to JSON string
@@ -385,7 +387,8 @@
         """
         # Try Redis cache first
         status_key = f"session:{session_id}:status"
-        cached = await self.redis.client.get(status_key)
+        client = await self.redis.get_client()
+        cached = await client.get(status_key)
         if cached:
             import json
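For context, the read side of the status cache now goes through `get_client()` for raw key access instead of reaching into `self.redis.client`. A sketch of the full pattern, with the database fallback (assumed, not shown in the hunk) reduced to returning `None`:

```python
import json
from typing import Optional

from app.core.redis_wrapper import DemoRedisWrapper

async def read_cached_status(redis: DemoRedisWrapper, session_id: str) -> Optional[dict]:
    status_key = f"session:{session_id}:status"  # key format taken from the hunk above
    client = await redis.get_client()            # raw shared Redis client for direct key access
    cached = await client.get(status_key)
    if cached:
        # Stored via setex(status_key, 7200, json.dumps(status_data)), so it
        # deserializes straight back into the status dict.
        return json.loads(cached)
    return None  # caller is assumed to fall back to the database on a cache miss
```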