385 lines
17 KiB
Python
385 lines
17 KiB
Python
"""
Demo Session Middleware

Handles demo account restrictions and virtual tenant injection.
"""
|
|
|
|
from fastapi import Request, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import Response
|
|
from typing import Optional
|
|
import uuid
|
|
import httpx
|
|
import structlog
|
|
import json
|
|
|
|
# Module-level structlog logger used by the demo middleware in this file.
logger = structlog.get_logger()
|
|
|
|
# Template demo tenants. Each demo session gets a clone of one of these;
# the templates themselves must never be modified.

# Professional demo (merged from San Pablo + La Espiga)
DEMO_TENANT_PROFESSIONAL = uuid.UUID("a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6")

# Enterprise chain demo: one parent tenant plus three child tenants
DEMO_TENANT_ENTERPRISE_CHAIN = uuid.UUID("c3d4e5f6-a7b8-49c0-d1e2-f3a4b5c6d7e8")
DEMO_TENANT_CHILD_1 = uuid.UUID("d4e5f6a7-b8c9-40d1-e2f3-a4b5c6d7e8f9")
DEMO_TENANT_CHILD_2 = uuid.UUID("e5f6a7b8-c9d0-41e2-f3a4-b5c6d7e8f9a0")
DEMO_TENANT_CHILD_3 = uuid.UUID("f6a7b8c9-d0e1-42f3-a4b5-c6d7e8f9a0b1")

# Canonical string form (lowercase, hyphenated) of every template tenant ID,
# for membership checks against tenant IDs received as strings.
DEMO_TENANT_IDS = {
    str(template_tenant)
    for template_tenant in (
        DEMO_TENANT_PROFESSIONAL,
        DEMO_TENANT_ENTERPRISE_CHAIN,
        DEMO_TENANT_CHILD_1,
        DEMO_TENANT_CHILD_2,
        DEMO_TENANT_CHILD_3,
    )
}
|
|
|
|
# Demo account type -> UUID of the user that demo requests act as.
# Each value is the owner.id from the matching 01-tenant.json fixture.
DEMO_USER_IDS = dict(
    # María García López (professional/01-tenant.json -> owner.id)
    professional="c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6",
    # Director (enterprise/parent/01-tenant.json -> owner.id)
    enterprise="d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7",
)
|
|
|
|
# Operations demo sessions may perform, keyed by HTTP method.
# Each value is a list of path patterns: a lone "*" allows every path, and
# an empty list blocks the method entirely. Pattern matching itself is
# implemented by DemoMiddleware._is_operation_allowed.
DEMO_ALLOWED_OPERATIONS = dict(
    # Reads are unrestricted
    GET=["*"],
    # A limited set of writes so demos can exercise realistic flows.
    # Forecast generation is blocked separately (see DEMO_BLOCKED_PATHS).
    POST=[
        "/api/v1/pos/sales",
        "/api/v1/pos/sessions",
        "/api/v1/orders",
        "/api/v1/inventory/adjustments",
        "/api/v1/sales",
        "/api/v1/production/batches",
        "/api/v1/tenants/batch/sales-summary",
        "/api/v1/tenants/batch/production-summary",
        "/api/v1/auth/me/onboarding/complete",  # onboarding completion (no-op for demos)
        "/api/v1/tenants/*/notifications/send",  # notifications (ML insights, alerts, ...)
    ],
    PUT=[
        "/api/v1/pos/sales/*",
        "/api/v1/orders/*",
        "/api/v1/inventory/stock/*",
        "/api/v1/auth/me/onboarding/step",  # onboarding step updates (no-op for demos)
    ],
    # Deletes and patches are never allowed
    DELETE=[],
    PATCH=[],
)
|
|
|
|
# Paths demo accounts can never use, even when the HTTP method is allowed:
# /api/v1/forecasts/{single,multi-day,batch}. Forecast generation needs
# trained AI models, which demo tenants don't have.
DEMO_BLOCKED_PATHS = [
    f"/api/v1/forecasts/{endpoint}"
    for endpoint in ("single", "multi-day", "batch")
]
|
|
|
|
# User-facing explanations (Spanish "message", English "message_en") for
# blocked-path categories, keyed by category name.
DEMO_BLOCKED_PATH_MESSAGE = {
    "forecasts": dict(
        message=(
            "La generación de pronósticos no está disponible para cuentas demo. "
            "Las cuentas demo no tienen modelos de IA entrenados."
        ),
        message_en=(
            "Forecast generation is not available for demo accounts. "
            "Demo accounts do not have trained AI models."
        ),
    ),
}
|
|
|
|
|
|
class DemoMiddleware(BaseHTTPMiddleware):
    """Middleware to handle demo session logic with Redis caching.

    Session lookup
        A demo session ID is read from the ``X-Demo-Session-Id`` header, or
        from the ``demo_session_id`` cookie as a fallback. It is resolved via
        the Redis cache, falling back to the demo-session service on a miss.

    Session state injection
        For a usable session, the virtual tenant and a synthetic demo user are
        written into ``request.scope["state"]``.

    Operation restrictions
        With a session, paths and operations outside the demo allow-list are
        rejected with 403.

    Template tenant protection
        Requests without a session that target one of the template tenants
        (``X-Tenant-Id`` in ``DEMO_TENANT_IDS``) are limited to read-only
        methods.

    Response marking
        Responses to demo requests carry ``X-Demo-Session: true``.
    """

    # Demo-service endpoints are forwarded untouched (prefix match).
    DEMO_SERVICE_PATHS = (
        "/api/v1/demo/accounts",
        "/api/v1/demo/sessions",
        "/api/v1/demo/stats",
        "/api/v1/demo/operations",
    )

    # Session statuses that still grant access. "failed" sessions can be
    # usable when some services cloned successfully; "active" is deprecated.
    VALID_SESSION_STATUSES = ("pending", "ready", "partial", "failed", "active")

    # Methods allowed when a template tenant is accessed without a session.
    READ_ONLY_METHODS = ("GET", "HEAD", "OPTIONS")

    # Seconds resolved session info stays in Redis (freshness vs. performance).
    SESSION_CACHE_TTL = 30

    def __init__(self, app, demo_session_url: str = "http://demo-session-service:8000"):
        super().__init__(app)
        self.demo_session_url = demo_session_url
        # None -> not initialised yet; False -> initialisation failed (don't retry).
        self._redis_client = None

    async def _get_redis_client(self):
        """Return the Redis client, initialising it lazily on first use.

        Returns None when Redis is unavailable. A failed initialisation is
        remembered via the ``False`` sentinel, so later calls don't retry.
        """
        if self._redis_client is None:
            try:
                from shared.redis_utils import get_redis_client

                self._redis_client = await get_redis_client()
                logger.debug("Demo middleware: Redis client initialized")
            except Exception as e:
                logger.warning(f"Demo middleware: Failed to get Redis client: {e}. Caching disabled.")
                self._redis_client = False  # Sentinel value to avoid retrying

        return self._redis_client if self._redis_client is not False else None

    @staticmethod
    def _cache_key(session_id: str) -> str:
        """Redis key under which a session's info is cached."""
        return f"demo_session:{session_id}"

    async def _get_cached_session(self, session_id: str) -> Optional[dict]:
        """Return cached session info, or None on a miss, missing Redis, or read error."""
        try:
            redis_client = await self._get_redis_client()
            if not redis_client:
                return None

            cached_data = await redis_client.get(self._cache_key(session_id))
            if cached_data:
                logger.debug("Demo middleware: Cache HIT", session_id=session_id)
                return json.loads(cached_data)

            logger.debug("Demo middleware: Cache MISS", session_id=session_id)
            return None
        except Exception as e:
            logger.warning(f"Demo middleware: Redis cache read error: {e}")
            return None

    async def _cache_session(self, session_id: str, session_info: dict, ttl: int = 30):
        """Store session info as JSON in Redis with a TTL.

        Best-effort: errors are logged and otherwise ignored.
        """
        try:
            redis_client = await self._get_redis_client()
            if not redis_client:
                return

            await redis_client.setex(self._cache_key(session_id), ttl, json.dumps(session_info))
            logger.debug(f"Demo middleware: Cached session {session_id} (TTL: {ttl}s)")
        except Exception as e:
            logger.warning(f"Demo middleware: Redis cache write error: {e}")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Apply demo-session handling, then forward the request downstream.

        Returns an error response without calling ``call_next`` when the
        request is rejected by the demo rules.
        """
        path = request.url.path

        # Demo-service endpoints (session creation, stats, ...) are never restricted.
        if path.startswith(self.DEMO_SERVICE_PATHS):
            return await call_next(request)

        session_id = (
            request.headers.get("X-Demo-Session-Id")
            or request.cookies.get("demo_session_id")
        )
        # Debug level: this runs for every request through the gateway.
        logger.debug("DemoMiddleware request", path=path, session_id=session_id)

        tenant_id = request.headers.get("X-Tenant-Id")

        if session_id:
            rejection = await self._handle_demo_session(request, session_id)
        elif tenant_id and tenant_id.strip().lower() in DEMO_TENANT_IDS:
            # Normalised so an upper-case UUID cannot bypass template write-protection.
            rejection = self._handle_template_tenant(request, tenant_id)
        else:
            rejection = None

        if rejection is not None:
            return rejection

        response = await call_next(request)

        if getattr(request.state, "is_demo_session", False):
            response.headers["X-Demo-Session"] = "true"

        return response

    async def _resolve_session(self, session_id: str) -> Optional[dict]:
        """Look up session info, preferring the Redis cache over the demo service.

        Successful service lookups are cached for ``SESSION_CACHE_TTL`` seconds.
        """
        session_info = await self._get_cached_session(session_id)
        if session_info:
            return session_info

        logger.debug("Demo middleware: Fetching from demo service", session_id=session_id)
        session_info = await self._get_session_info(session_id)
        if session_info:
            await self._cache_session(session_id, session_info, ttl=self.SESSION_CACHE_TTL)
        return session_info

    async def _handle_demo_session(self, request: Request, session_id: str) -> Optional[Response]:
        """Validate a demo session, inject its context, and enforce demo restrictions.

        Returns a JSON error response if the request must be rejected:
        - 401 for an invalid or expired session, or an unexpected error;
        - 403 for a blocked path or a disallowed operation.

        Returns None when the request may proceed.
        """
        try:
            session_info = await self._resolve_session(session_id)
            current_status = session_info.get("status") if session_info else None

            if not session_info or current_status not in self.VALID_SESSION_STATUSES:
                # Session expired, invalid, or in failed/destroyed state
                logger.warning("Invalid demo session state", session_id=session_id, status=current_status)
                return JSONResponse(
                    status_code=401,
                    content={
                        "error": "session_expired",
                        "message": "Tu sesión demo ha expirado. Crea una nueva sesión para continuar.",
                        "message_en": "Your demo session has expired. Create a new session to continue.",
                        "session_status": current_status,
                    },
                )

            self._inject_session_state(request, session_id, session_info, current_status)
            await self._update_session_activity(session_id)

            blocked_reason = self._check_blocked_path(request.url.path)
            if blocked_reason:
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "demo_restriction",
                        **blocked_reason,
                        "upgrade_url": "/pricing",
                        "session_expires_at": session_info.get("expires_at"),
                    },
                )

            if not self._is_operation_allowed(request.method, request.url.path):
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "demo_restriction",
                        "message": "Esta operación no está permitida en cuentas demo. "
                                   "Las sesiones demo se eliminan automáticamente después de 30 minutos. "
                                   "Suscríbete para obtener acceso completo.",
                        "message_en": "This operation is not allowed in demo accounts. "
                                      "Demo sessions are automatically deleted after 30 minutes. "
                                      "Subscribe for full access.",
                        "upgrade_url": "/pricing",
                        "session_expires_at": session_info.get("expires_at"),
                    },
                )

            return None

        except Exception as e:
            logger.error("Demo middleware error", error=str(e), session_id=session_id, path=request.url.path)
            # On error, reject instead of letting the request through unvalidated.
            return JSONResponse(
                status_code=401,
                content={
                    "error": "session_error",
                    "message": "Error validando sesión demo. Por favor, inténtalo de nuevo.",
                    "message_en": "Error validating demo session. Please try again.",
                },
            )

    def _inject_session_state(
        self, request: Request, session_id: str, session_info: dict, current_status: Optional[str]
    ) -> None:
        """Write the virtual tenant and a synthetic demo user into the request state.

        Raises KeyError when ``session_info`` has no ``virtual_tenant_id``. A
        missing ``demo_account_type`` falls back to ``"professional"``
        everywhere it is used.
        """
        account_type = session_info.get("demo_account_type", "professional")
        virtual_tenant_id = session_info["virtual_tenant_id"]

        # Use scope state directly to avoid potential state property issues.
        state = request.scope.setdefault("state", {})
        state["tenant_id"] = virtual_tenant_id
        state["is_demo_session"] = True
        state["demo_account_type"] = account_type
        state["demo_session_status"] = current_status  # Track status for monitoring

        # Unknown account types act as the professional demo user.
        demo_user_id = DEMO_USER_IDS.get(account_type, DEMO_USER_IDS["professional"])
        subscription_tier = "enterprise" if account_type == "enterprise" else "professional"

        # Synthetic user context so the request can pass AuthMiddleware. The
        # subscription data is embedded here so no subscription HTTP lookup is needed.
        state["user"] = {
            "user_id": demo_user_id,
            "email": f"demo-{session_id}@demo.local",
            "tenant_id": virtual_tenant_id,
            "role": "owner",  # Demo users have owner role
            "is_demo": True,
            "demo_session_id": session_id,
            "demo_account_type": account_type,
            "demo_session_status": current_status,
            "subscription_tier": subscription_tier,
            "subscription_status": "active",
            "subscription_from_jwt": True,  # Flag to skip HTTP calls
        }

    def _handle_template_tenant(self, request: Request, tenant_id: str) -> Optional[Response]:
        """Mark direct template-tenant access as demo traffic and reject writes.

        Returns a 403 response for non-read-only methods, otherwise None.
        """
        state = request.scope.setdefault("state", {})
        state["is_demo_session"] = True
        state["tenant_id"] = tenant_id

        if request.method in self.READ_ONLY_METHODS:
            return None

        return JSONResponse(
            status_code=403,
            content={
                "error": "demo_restriction",
                "message": "Acceso directo al tenant demo no permitido. Crea una sesión demo.",
                "message_en": "Direct access to demo tenant not allowed. Create a demo session.",
            },
        )

    async def _get_session_info(self, session_id: str) -> Optional[dict]:
        """Fetch session info from the demo-session service.

        Authenticates with a freshly created gateway service JWT.

        Returns the decoded JSON body on HTTP 200. Returns None on any other
        status, on a malformed session ID, or on any error (which is logged).
        """
        if session_id in (".", ".."):
            # These would be collapsed as dot-segments in the URL path and
            # address a different demo-service endpoint.
            logger.warning("Rejected malformed demo session id", session_id=session_id)
            return None

        try:
            from urllib.parse import quote

            # Create JWT service token for gateway-to-demo-session communication
            from shared.auth.jwt_handler import JWTHandler
            from app.core.config import settings

            jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
            service_token = jwt_handler.create_service_token(service_name="gateway")

            # session_id is client-supplied (header/cookie). Percent-encode it so
            # characters like "/", "?" or "#" cannot redirect this
            # service-token-authenticated call to another path or add a query.
            encoded_session_id = quote(session_id, safe="")

            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(
                    f"{self.demo_session_url}/api/v1/demo/sessions/{encoded_session_id}",
                    headers={"Authorization": f"Bearer {service_token}"},
                )
                if response.status_code == 200:
                    return response.json()

                logger.warning(
                    "Demo session fetch failed",
                    session_id=session_id,
                    status_code=response.status_code,
                    response_text=response.text[:200],
                )
                return None
        except Exception as e:
            logger.error("Failed to get session info", session_id=session_id, error=str(e))
            return None

    async def _update_session_activity(self, session_id: str):
        """Update session activity timestamp (currently a no-op).

        Activity tracking is handled by the demo service internally, so no
        explicit endpoint call is needed here.
        """
        pass

    def _check_blocked_path(self, path: str) -> Optional[dict]:
        """Return the rejection message dict if ``path`` is blocked for demos, else None.

        A path is blocked when any entry of ``DEMO_BLOCKED_PATHS`` occurs in it
        as a substring.
        """
        for blocked_path in DEMO_BLOCKED_PATHS:
            if blocked_path in path:
                # Determine which category of blocked path
                if "forecast" in blocked_path:
                    return DEMO_BLOCKED_PATH_MESSAGE["forecasts"]
                # Can add more categories here in the future
                return {
                    "message": "Esta funcionalidad no está disponible para cuentas demo.",
                    "message_en": "This functionality is not available for demo accounts.",
                }
        return None

    @staticmethod
    def _path_matches(pattern: str, path: str) -> bool:
        """Match a request path against one allow-list pattern.

        - Trailing ``*`` (``/a/b/*``): prefix match, e.g. ``/a/b/123`` or
          ``/a/b/1/2``.
        - A whole-segment ``*`` elsewhere (``/a/*/c``): matches exactly one
          non-empty segment, e.g. ``/a/42/c`` but not ``/a/c`` or ``/a/1/2/c``.
        - Anything else: exact match.

        A pattern ending in ``*`` is always a prefix match, so combining
        interior and trailing wildcards is not supported.
        """
        if pattern.endswith("*"):
            return path.startswith(pattern[:-1])
        if "*" in pattern:
            pattern_segments = pattern.split("/")
            path_segments = path.split("/")
            return len(pattern_segments) == len(path_segments) and all(
                segment == expected or (expected == "*" and segment != "")
                for expected, segment in zip(pattern_segments, path_segments)
            )
        return path == pattern

    def _is_operation_allowed(self, method: str, path: str) -> bool:
        """Return True if demo sessions may perform ``method`` on ``path``.

        Patterns come from ``DEMO_ALLOWED_OPERATIONS[method]`` (see
        ``_path_matches``). A lone ``"*"`` allows every path, and methods with
        no entry are denied.

        HEAD follows the GET rules, since it is a GET without a response body.
        OPTIONS is always allowed because it does not modify data (e.g. CORS
        preflight).
        """
        if method == "OPTIONS":
            return True
        if method == "HEAD":
            method = "GET"

        allowed_paths = DEMO_ALLOWED_OPERATIONS.get(method, [])
        if "*" in allowed_paths:
            return True

        return any(self._path_matches(pattern, path) for pattern in allowed_paths)
|