Initial commit - production deployment
This commit is contained in:
0
gateway/app/middleware/__init__.py
Normal file
0
gateway/app/middleware/__init__.py
Normal file
649
gateway/app/middleware/auth.py
Normal file
649
gateway/app/middleware/auth.py
Normal file
@@ -0,0 +1,649 @@
|
||||
# gateway/app/middleware/auth.py
|
||||
"""
|
||||
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
|
||||
FIXED VERSION - Proper JWT verification and token structure handling
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# JWT handler for local token validation - using SAME configuration as auth service
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
# Routes that don't require authentication
|
||||
PUBLIC_ROUTES = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/api/v1/auth/verify",
|
||||
"/api/v1/auth/start-registration", # Registration step 1 - SetupIntent creation
|
||||
"/api/v1/auth/complete-registration", # Registration step 2 - Completion after 3DS
|
||||
"/api/v1/registration/payment-setup", # New registration payment setup endpoint
|
||||
"/api/v1/registration/complete", # New registration completion endpoint
|
||||
"/api/v1/registration/state/", # Registration state check
|
||||
"/api/v1/auth/verify-email", # Email verification
|
||||
"/api/v1/auth/password/reset-request", # Password reset request - no auth required
|
||||
"/api/v1/auth/password/reset", # Password reset with token - no auth required
|
||||
"/api/v1/nominatim/search",
|
||||
"/api/v1/plans",
|
||||
"/api/v1/demo/accounts",
|
||||
"/api/v1/demo/sessions",
|
||||
"/api/v1/webhooks/stripe", # Stripe webhook endpoint - bypasses auth for signature verification
|
||||
"/api/v1/webhooks/generic", # Generic webhook endpoint
|
||||
"/api/v1/telemetry/v1/traces", # Frontend telemetry traces - no auth for performance
|
||||
"/api/v1/telemetry/v1/metrics", # Frontend telemetry metrics - no auth for performance
|
||||
"/api/v1/telemetry/health" # Telemetry health check
|
||||
]
|
||||
|
||||
# Routes accessible with demo session (no JWT required, just demo session header)
|
||||
DEMO_ACCESSIBLE_ROUTES = [
|
||||
"/api/v1/tenants/", # All tenant endpoints accessible in demo mode
|
||||
]
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Enhanced Authentication Middleware with Tenant Access Control
|
||||
"""
|
||||
|
||||
def __init__(self, app, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client # For caching and rate limiting
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with enhanced authentication and tenant access control"""
|
||||
|
||||
# Skip authentication for OPTIONS requests (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# SECURITY: Remove any incoming sensitive headers using HeaderManager
|
||||
header_manager.sanitize_incoming_headers(request)
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# ✅ Check if demo middleware already set user context OR check query param for SSE
|
||||
demo_session_header = request.headers.get("X-Demo-Session-Id")
|
||||
demo_session_query = request.query_params.get("demo_session_id") # For SSE endpoint
|
||||
logger.info(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")
|
||||
|
||||
# For SSE endpoint with demo_session_id in query params, validate it here
|
||||
if request.url.path == "/api/v1/events" and demo_session_query and not hasattr(request.state, "is_demo_session"):
|
||||
logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_query}")
|
||||
# Validate demo session via demo-session service using JWT service token
|
||||
import httpx
|
||||
try:
|
||||
# Create service token for gateway-to-demo-session communication
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://demo-session-service:8000/api/v1/demo/sessions/{demo_session_query}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
session_data = response.json()
|
||||
# Set demo session context
|
||||
request.state.is_demo_session = True
|
||||
request.state.user = {
|
||||
"user_id": f"demo-user-{demo_session_query}",
|
||||
"email": f"demo-{demo_session_query}@demo.local",
|
||||
"tenant_id": session_data.get("virtual_tenant_id"),
|
||||
"demo_session_id": demo_session_query,
|
||||
}
|
||||
request.state.tenant_id = session_data.get("virtual_tenant_id")
|
||||
logger.info(f"✅ Demo session validated for SSE: {demo_session_query}")
|
||||
else:
|
||||
logger.warning(f"Invalid demo session for SSE: {demo_session_query}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid demo session"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate demo session for SSE: {e}")
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"detail": "Demo session service unavailable"}
|
||||
)
|
||||
|
||||
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
|
||||
# Demo middleware already validated and set user context
|
||||
# But we still need to inject context headers for downstream services
|
||||
user_context = request.state.user
|
||||
tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
|
||||
|
||||
# For demo sessions, get the actual subscription tier from the tenant service
|
||||
# instead of always defaulting to enterprise
|
||||
if not user_context.get("subscription_tier"):
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
else:
|
||||
# Fallback to enterprise for demo if no tier is found
|
||||
user_context["subscription_tier"] = "enterprise"
|
||||
|
||||
logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)
|
||||
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
# ✅ STEP 1: Extract and validate JWT token
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# ✅ STEP 2: Verify token and get user context
|
||||
user_context = await self._verify_token(token, request)
|
||||
if not user_context:
|
||||
logger.warning(f"Invalid token for route: {request.url.path}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "User not authenticated"}
|
||||
)
|
||||
|
||||
# ✅ STEP 3: Extract tenant context from URL using shared utility
|
||||
tenant_id = extract_tenant_id_from_path(request.url.path)
|
||||
|
||||
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
|
||||
if tenant_id and is_tenant_scoped_path(request.url.path):
|
||||
# Skip tenant access verification for service tokens (services have admin access)
|
||||
if user_context.get("type") != "service":
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": f"Access denied to tenant {tenant_id}"}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Service token granted access to tenant {tenant_id}",
|
||||
service=user_context.get("service"))
|
||||
|
||||
# Get tenant subscription tier and inject into user context
|
||||
# NEW: Use JWT data if available, skip HTTP call
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
subscription_tier = user_context.get("subscription_tier")
|
||||
logger.debug("Using subscription tier from JWT", tier=subscription_tier)
|
||||
else:
|
||||
# Only for old tokens - remove after full rollout
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
|
||||
# Check hierarchical access to determine access type and permissions
|
||||
hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
# Set tenant context in request state
|
||||
request.state.tenant_id = tenant_id
|
||||
request.state.tenant_verified = True
|
||||
request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
|
||||
request.state.can_view_children = hierarchical_access.get("can_view_children", False)
|
||||
|
||||
logger.debug(f"Tenant access verified",
|
||||
user_id=user_context["user_id"],
|
||||
tenant_id=tenant_id,
|
||||
subscription_tier=subscription_tier,
|
||||
access_type=hierarchical_access.get("access_type"),
|
||||
can_view_children=hierarchical_access.get("can_view_children"),
|
||||
path=request.url.path)
|
||||
|
||||
# ✅ STEP 5: Inject user context into request
|
||||
request.state.user = user_context
|
||||
request.state.authenticated = True
|
||||
|
||||
# ✅ STEP 6: Add context headers for downstream services
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
logger.debug(f"Authenticated request",
|
||||
user_email=user_context['email'],
|
||||
tenant_id=tenant_id,
|
||||
path=request.url.path)
|
||||
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add token expiry warning header if token is near expiry
|
||||
if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
|
||||
response.headers["X-Token-Refresh-Suggested"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
"""Check if route requires authentication"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract JWT token from Authorization header or query params for SSE.
|
||||
|
||||
For SSE endpoints (/api/v1/events), browsers' EventSource API cannot send
|
||||
custom headers, so we must accept token as query parameter.
|
||||
For all other routes, token must be in Authorization header (more secure).
|
||||
|
||||
Security note: Query param tokens are logged. Use short expiry and filter logs.
|
||||
"""
|
||||
# SSE endpoint exception: token in query param (EventSource API limitation)
|
||||
if request.url.path == "/api/v1/events":
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
logger.debug("Token extracted from query param for SSE endpoint")
|
||||
return token
|
||||
logger.warning("SSE request missing token in query param")
|
||||
return None
|
||||
|
||||
# Standard authentication: Authorization header for all other routes
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token with improved fallback strategy
|
||||
FIXED: Better error handling and token structure validation
|
||||
"""
|
||||
|
||||
# Strategy 1: Try local JWT validation first (fast)
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload and self._validate_token_payload(payload):
|
||||
logger.debug("Token validated locally")
|
||||
|
||||
# NEW: Check token freshness for subscription changes (async)
|
||||
if payload.get("tenant_id") and request:
|
||||
try:
|
||||
is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
|
||||
if not is_fresh:
|
||||
logger.warning("Stale token detected - subscription changed since token was issued",
|
||||
user_id=payload.get("user_id"),
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is stale - subscription has changed"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check failed, allowing token", error=str(e))
|
||||
# Allow token if check fails (fail open for availability)
|
||||
|
||||
# Check if token is near expiry and set flag for response header
|
||||
if request:
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
request.state.token_near_expiry = True
|
||||
|
||||
# Convert JWT payload to user context format
|
||||
return self._jwt_payload_to_user_context(payload)
|
||||
except Exception as e:
|
||||
logger.debug(f"Local token validation failed: {e}")
|
||||
|
||||
# Strategy 2: Check cache for recently validated tokens
|
||||
if self.redis_client:
|
||||
try:
|
||||
cached_user = await self._get_cached_user(token)
|
||||
if cached_user:
|
||||
logger.debug("Token found in cache")
|
||||
return cached_user
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
|
||||
# Strategy 3: Verify with auth service (authoritative)
|
||||
try:
|
||||
user_context = await self._verify_with_auth_service(token)
|
||||
if user_context:
|
||||
# Cache successful validation
|
||||
if self.redis_client:
|
||||
await self._cache_user(token, user_context)
|
||||
logger.debug("Token validated by auth service")
|
||||
return user_context
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service validation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate JWT payload has required fields
|
||||
FIXED: Updated to match actual token structure from auth service
|
||||
"""
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
missing_fields = [field for field in required_fields if field not in payload]
|
||||
|
||||
if missing_fields:
|
||||
logger.warning(f"Token payload missing fields: {missing_fields}")
|
||||
return False
|
||||
|
||||
# Validate token type
|
||||
token_type = payload.get("type")
|
||||
if token_type not in ["access", "service"]:
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return False
|
||||
|
||||
# Check if token is near expiry (within 5 minutes) and log warning
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
|
||||
|
||||
# NEW: Check token freshness for subscription changes
|
||||
if payload.get("tenant_id"):
|
||||
try:
|
||||
# Note: We can't await here because this is a sync function
|
||||
# Token freshness will be checked in the async dispatch method
|
||||
# For now, just log that we would check freshness
|
||||
logger.debug("Token freshness check would be performed in async context",
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check setup failed", error=str(e))
|
||||
|
||||
# FIX: Validate service tokens with tenant context for tenant-scoped routes
|
||||
if token_type == "service" and payload.get("tenant_id"):
|
||||
# Service tokens with tenant context are valid for tenant-scoped operations
|
||||
logger.debug("Service token with tenant context validated",
|
||||
service=payload.get("service"), tenant_id=payload.get("tenant_id"))
|
||||
|
||||
return True
|
||||
|
||||
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate JWT payload integrity beyond signature verification.
|
||||
Prevents edge cases where payload might be malformed.
|
||||
"""
|
||||
# Required fields must exist
|
||||
required_fields = ["user_id", "email", "exp", "iat", "iss"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
logger.warning("JWT missing required fields", missing=[f for f in required_fields if f not in payload])
|
||||
return False
|
||||
|
||||
# Issuer must be our auth service
|
||||
if payload.get("iss") != "bakery-auth":
|
||||
logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
|
||||
return False
|
||||
|
||||
# Token type must be valid
|
||||
if payload.get("type") not in ["access", "service"]:
|
||||
logger.warning("JWT has invalid type", token_type=payload.get("type"))
|
||||
return False
|
||||
|
||||
# Subscription tier must be valid if present
|
||||
valid_tiers = ["starter", "professional", "enterprise"]
|
||||
if payload.get("subscription"):
|
||||
tier = payload["subscription"].get("tier", "").lower()
|
||||
if tier and tier not in valid_tiers:
|
||||
logger.warning("JWT has invalid subscription tier", tier=tier)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
|
||||
"""
|
||||
Verify token was issued after the last subscription change.
|
||||
Prevents use of stale tokens with old subscription data.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return True # Skip check if no Redis
|
||||
|
||||
try:
|
||||
subscription_changed_at = await self.redis_client.get(
|
||||
f"tenant:{tenant_id}:subscription_changed_at"
|
||||
)
|
||||
|
||||
if subscription_changed_at:
|
||||
changed_timestamp = float(subscription_changed_at)
|
||||
token_issued_at = payload.get("iat", 0)
|
||||
|
||||
if token_issued_at < changed_timestamp:
|
||||
logger.warning(
|
||||
"Token issued before subscription change",
|
||||
token_iat=token_issued_at,
|
||||
subscription_changed=changed_timestamp,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return False # Token is stale
|
||||
except Exception as e:
|
||||
logger.warning("Failed to check token freshness", error=str(e))
|
||||
|
||||
return True
|
||||
|
||||
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert JWT payload to user context format
|
||||
FIXED: Proper mapping between JWT structure and user context
|
||||
"""
|
||||
# NEW: Validate JWT integrity before processing
|
||||
if not self._validate_jwt_integrity(payload):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT payload"
|
||||
)
|
||||
|
||||
base_context = {
|
||||
"user_id": payload["user_id"],
|
||||
"email": payload["email"],
|
||||
"exp": payload["exp"],
|
||||
"valid": True,
|
||||
"role": payload.get("role", "user"),
|
||||
}
|
||||
|
||||
# NEW: Extract subscription from JWT
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
base_context["tenant_role"] = payload.get("tenant_role", "member")
|
||||
|
||||
if payload.get("subscription"):
|
||||
sub = payload["subscription"]
|
||||
base_context["subscription_tier"] = sub.get("tier", "starter")
|
||||
base_context["subscription_status"] = sub.get("status", "active")
|
||||
base_context["subscription_from_jwt"] = True # Flag to skip HTTP
|
||||
|
||||
if payload.get("tenant_access"):
|
||||
base_context["tenant_access"] = payload["tenant_access"]
|
||||
|
||||
if payload.get("service"):
|
||||
service_name = payload["service"]
|
||||
base_context["service"] = service_name
|
||||
base_context["type"] = "service"
|
||||
base_context["role"] = "admin"
|
||||
base_context["user_id"] = f"{service_name}-service"
|
||||
base_context["email"] = f"{service_name}-service@internal"
|
||||
|
||||
# FIX: Service tokens with tenant context should use that tenant_id
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
logger.debug(f"Service authentication with tenant context: {service_name}, tenant_id: {payload['tenant_id']}")
|
||||
else:
|
||||
logger.debug(f"Service authentication: {service_name}")
|
||||
|
||||
return base_context
|
||||
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify token with auth service
|
||||
FIXED: Improved error handling and response parsing
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
auth_response = response.json()
|
||||
|
||||
# Validate auth service response structure
|
||||
if auth_response.get("valid") and auth_response.get("user_id"):
|
||||
return {
|
||||
"user_id": auth_response["user_id"],
|
||||
"email": auth_response["email"],
|
||||
"exp": auth_response.get("exp"),
|
||||
"valid": True
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Auth service returned invalid response: {auth_response}")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Auth service returned {response.status_code}: {response.text}")
|
||||
return None
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout during token verification")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service error: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get user context from cache
|
||||
FIXED: Better error handling and JSON parsing
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}" # Use modulo for shorter keys
|
||||
try:
|
||||
cached_data = await self.redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
if isinstance(cached_data, bytes):
|
||||
cached_data = cached_data.decode()
|
||||
return json.loads(cached_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse cached user data: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Cache user context
|
||||
FIXED: Better error handling and expiration
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}"
|
||||
try:
|
||||
# Cache for 5 minutes (shorter than token expiry)
|
||||
await self.redis_client.setex(
|
||||
cache_key,
|
||||
300, # 5 minutes
|
||||
json.dumps(user_context)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache user context: {e}")
|
||||
|
||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services using unified HeaderManager
|
||||
"""
|
||||
# Use unified HeaderManager for consistent header injection
|
||||
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
# Add hierarchical access headers if tenant context exists
|
||||
if tenant_id:
|
||||
# If this is hierarchical access, include parent tenant ID
|
||||
# Get parent tenant ID from the auth service if available
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
|
||||
headers={"Authorization": request.headers.get("Authorization", "")}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
if parent_tenant_id:
|
||||
# Add parent tenant ID using HeaderManager for consistency
|
||||
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
header_manager.add_header_for_middleware(request, header_name, header_value)
|
||||
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||
pass
|
||||
|
||||
return injected_headers
|
||||
|
||||
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get tenant subscription tier using fast cached endpoint
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
request: FastAPI request for headers
|
||||
|
||||
Returns:
|
||||
Subscription tier string or None
|
||||
"""
|
||||
try:
|
||||
# Use fast cached subscription tier endpoint (has its own Redis caching)
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
headers = {"Authorization": request.headers.get("Authorization", "")}
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
tier_data = response.json()
|
||||
subscription_tier = tier_data.get("tier", "starter")
|
||||
|
||||
logger.debug("Subscription tier from cached endpoint",
|
||||
tenant_id=tenant_id,
|
||||
tier=subscription_tier,
|
||||
cached=tier_data.get("cached", False))
|
||||
return subscription_tier
|
||||
else:
|
||||
logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
|
||||
return "starter" # Default to starter
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tenant subscription tier: {e}")
|
||||
return "starter" # Default to starter on error
|
||||
384
gateway/app/middleware/demo_middleware.py
Normal file
384
gateway/app/middleware/demo_middleware.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Demo Session Middleware
|
||||
Handles demo account restrictions and virtual tenant injection
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import structlog
|
||||
import json
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Fixed Demo Tenant IDs (these are the template tenants that will be cloned)
|
||||
# Professional demo (merged from San Pablo + La Espiga)
|
||||
DEMO_TENANT_PROFESSIONAL = uuid.UUID("a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6")
|
||||
|
||||
# Enterprise chain demo (parent + 3 children)
|
||||
DEMO_TENANT_ENTERPRISE_CHAIN = uuid.UUID("c3d4e5f6-a7b8-49c0-d1e2-f3a4b5c6d7e8")
|
||||
DEMO_TENANT_CHILD_1 = uuid.UUID("d4e5f6a7-b8c9-40d1-e2f3-a4b5c6d7e8f9")
|
||||
DEMO_TENANT_CHILD_2 = uuid.UUID("e5f6a7b8-c9d0-41e2-f3a4-b5c6d7e8f9a0")
|
||||
DEMO_TENANT_CHILD_3 = uuid.UUID("f6a7b8c9-d0e1-42f3-a4b5-c6d7e8f9a0b1")
|
||||
|
||||
# Demo tenant IDs (base templates)
|
||||
DEMO_TENANT_IDS = {
|
||||
str(DEMO_TENANT_PROFESSIONAL), # Professional demo tenant
|
||||
str(DEMO_TENANT_ENTERPRISE_CHAIN), # Enterprise chain parent
|
||||
str(DEMO_TENANT_CHILD_1), # Enterprise chain child 1
|
||||
str(DEMO_TENANT_CHILD_2), # Enterprise chain child 2
|
||||
str(DEMO_TENANT_CHILD_3), # Enterprise chain child 3
|
||||
}
|
||||
|
||||
# Demo user IDs - Maps demo account type to actual user UUIDs from fixture files
|
||||
# These IDs are the owner IDs from the respective 01-tenant.json files
|
||||
DEMO_USER_IDS = {
|
||||
"professional": "c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6", # María García López (professional/01-tenant.json -> owner.id)
|
||||
"enterprise": "d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7" # Director (enterprise/parent/01-tenant.json -> owner.id)
|
||||
}
|
||||
|
||||
# Allowed operations for demo accounts (limited write)
|
||||
DEMO_ALLOWED_OPERATIONS = {
|
||||
# Read operations - all allowed
|
||||
"GET": ["*"],
|
||||
|
||||
# Limited write operations for realistic testing
|
||||
"POST": [
|
||||
"/api/v1/pos/sales",
|
||||
"/api/v1/pos/sessions",
|
||||
"/api/v1/orders",
|
||||
"/api/v1/inventory/adjustments",
|
||||
"/api/v1/sales",
|
||||
"/api/v1/production/batches",
|
||||
"/api/v1/tenants/batch/sales-summary",
|
||||
"/api/v1/tenants/batch/production-summary",
|
||||
"/api/v1/auth/me/onboarding/complete", # Allow completing onboarding (no-op for demos)
|
||||
"/api/v1/tenants/*/notifications/send", # Allow notifications (ML insights, alerts, etc.)
|
||||
# Note: Forecast generation is explicitly blocked (see DEMO_BLOCKED_PATHS)
|
||||
],
|
||||
|
||||
"PUT": [
|
||||
"/api/v1/pos/sales/*",
|
||||
"/api/v1/orders/*",
|
||||
"/api/v1/inventory/stock/*",
|
||||
"/api/v1/auth/me/onboarding/step", # Allow onboarding step updates (no-op for demos)
|
||||
],
|
||||
|
||||
# Blocked operations
|
||||
"DELETE": [], # No deletes allowed
|
||||
"PATCH": [], # No patches allowed
|
||||
}
|
||||
|
||||
# Explicitly blocked paths for demo accounts (even if method would be allowed)
|
||||
# These require trained AI models which demo tenants don't have
|
||||
DEMO_BLOCKED_PATHS = [
|
||||
"/api/v1/forecasts/single",
|
||||
"/api/v1/forecasts/multi-day",
|
||||
"/api/v1/forecasts/batch",
|
||||
]
|
||||
|
||||
DEMO_BLOCKED_PATH_MESSAGE = {
|
||||
"forecasts": {
|
||||
"message": "La generación de pronósticos no está disponible para cuentas demo. "
|
||||
"Las cuentas demo no tienen modelos de IA entrenados.",
|
||||
"message_en": "Forecast generation is not available for demo accounts. "
|
||||
"Demo accounts do not have trained AI models.",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DemoMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to handle demo session logic with Redis caching"""
|
||||
|
||||
def __init__(self, app, demo_session_url: str = "http://demo-session-service:8000"):
|
||||
super().__init__(app)
|
||||
self.demo_session_url = demo_session_url
|
||||
self._redis_client = None
|
||||
|
||||
async def _get_redis_client(self):
|
||||
"""Get or lazily initialize Redis client"""
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
from shared.redis_utils import get_redis_client
|
||||
self._redis_client = await get_redis_client()
|
||||
logger.debug("Demo middleware: Redis client initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Failed to get Redis client: {e}. Caching disabled.")
|
||||
self._redis_client = False # Sentinel value to avoid retrying
|
||||
|
||||
return self._redis_client if self._redis_client is not False else None
|
||||
|
||||
async def _get_cached_session(self, session_id: str) -> Optional[dict]:
|
||||
"""Get session info from Redis cache"""
|
||||
try:
|
||||
redis_client = await self._get_redis_client()
|
||||
if not redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"demo_session:{session_id}"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
|
||||
if cached_data:
|
||||
logger.debug("Demo middleware: Cache HIT", session_id=session_id)
|
||||
return json.loads(cached_data)
|
||||
else:
|
||||
logger.debug("Demo middleware: Cache MISS", session_id=session_id)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Redis cache read error: {e}")
|
||||
return None
|
||||
|
||||
async def _cache_session(self, session_id: str, session_info: dict, ttl: int = 30):
|
||||
"""Cache session info in Redis with TTL"""
|
||||
try:
|
||||
redis_client = await self._get_redis_client()
|
||||
if not redis_client:
|
||||
return
|
||||
|
||||
cache_key = f"demo_session:{session_id}"
|
||||
serialized = json.dumps(session_info)
|
||||
await redis_client.setex(cache_key, ttl, serialized)
|
||||
logger.debug(f"Demo middleware: Cached session {session_id} (TTL: {ttl}s)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Redis cache write error: {e}")
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request through demo middleware"""
|
||||
|
||||
# Skip demo middleware for demo service endpoints
|
||||
demo_service_paths = [
|
||||
"/api/v1/demo/accounts",
|
||||
"/api/v1/demo/sessions",
|
||||
"/api/v1/demo/stats",
|
||||
"/api/v1/demo/operations",
|
||||
]
|
||||
|
||||
if any(request.url.path.startswith(path) or request.url.path == path for path in demo_service_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract session ID from header or cookie
|
||||
session_id = (
|
||||
request.headers.get("X-Demo-Session-Id") or
|
||||
request.cookies.get("demo_session_id")
|
||||
)
|
||||
|
||||
logger.info(f"🎭 DemoMiddleware - path: {request.url.path}, session_id: {session_id}")
|
||||
|
||||
# Extract tenant ID from request
|
||||
tenant_id = request.headers.get("X-Tenant-Id")
|
||||
|
||||
# Check if this is a demo session request
|
||||
if session_id:
|
||||
try:
|
||||
# PERFORMANCE OPTIMIZATION: Check Redis cache first before HTTP call
|
||||
session_info = await self._get_cached_session(session_id)
|
||||
|
||||
if not session_info:
|
||||
# Cache miss - fetch from demo service
|
||||
logger.debug("Demo middleware: Fetching from demo service", session_id=session_id)
|
||||
session_info = await self._get_session_info(session_id)
|
||||
|
||||
# Cache the result if successful (30s TTL to balance freshness vs performance)
|
||||
if session_info:
|
||||
await self._cache_session(session_id, session_info, ttl=30)
|
||||
|
||||
# Accept pending, ready, partial, failed (if data exists), and active (deprecated) statuses
|
||||
# Even "failed" sessions can be usable if some services succeeded
|
||||
valid_statuses = ["pending", "ready", "partial", "failed", "active"]
|
||||
current_status = session_info.get("status") if session_info else None
|
||||
|
||||
if session_info and current_status in valid_statuses:
|
||||
# NOTE: Path transformation for demo-user removed.
|
||||
# Frontend now receives the real demo_user_id from session creation
|
||||
# and uses it directly in API calls.
|
||||
|
||||
# Inject virtual tenant ID
|
||||
# Use scope state directly to avoid potential state property issues
|
||||
request.scope.setdefault("state", {})
|
||||
state = request.scope["state"]
|
||||
state["tenant_id"] = session_info["virtual_tenant_id"]
|
||||
state["is_demo_session"] = True
|
||||
state["demo_account_type"] = session_info["demo_account_type"]
|
||||
state["demo_session_status"] = current_status # Track status for monitoring
|
||||
|
||||
# Inject demo user context for auth middleware
|
||||
# Uses DEMO_USER_IDS constant defined at module level
|
||||
demo_user_id = DEMO_USER_IDS.get(
|
||||
session_info.get("demo_account_type", "professional"),
|
||||
DEMO_USER_IDS["professional"]
|
||||
)
|
||||
|
||||
# This allows the request to pass through AuthMiddleware
|
||||
# NEW: Extract subscription tier from demo account type
|
||||
subscription_tier = "enterprise" if session_info.get("demo_account_type") == "enterprise" else "professional"
|
||||
|
||||
state["user"] = {
|
||||
"user_id": demo_user_id, # Use actual demo user UUID
|
||||
"email": f"demo-{session_id}@demo.local",
|
||||
"tenant_id": session_info["virtual_tenant_id"],
|
||||
"role": "owner", # Demo users have owner role
|
||||
"is_demo": True,
|
||||
"demo_session_id": session_id,
|
||||
"demo_account_type": session_info.get("demo_account_type", "professional"),
|
||||
"demo_session_status": current_status,
|
||||
# NEW: Subscription context (no HTTP call needed!)
|
||||
"subscription_tier": subscription_tier,
|
||||
"subscription_status": "active",
|
||||
"subscription_from_jwt": True # Flag to skip HTTP calls
|
||||
}
|
||||
|
||||
# Update activity
|
||||
await self._update_session_activity(session_id)
|
||||
|
||||
# Check if path is explicitly blocked
|
||||
blocked_reason = self._check_blocked_path(request.url.path)
|
||||
if blocked_reason:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
**blocked_reason,
|
||||
"upgrade_url": "/pricing",
|
||||
"session_expires_at": session_info.get("expires_at")
|
||||
}
|
||||
)
|
||||
|
||||
# Check if operation is allowed
|
||||
if not self._is_operation_allowed(request.method, request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
"message": "Esta operación no está permitida en cuentas demo. "
|
||||
"Las sesiones demo se eliminan automáticamente después de 30 minutos. "
|
||||
"Suscríbete para obtener acceso completo.",
|
||||
"message_en": "This operation is not allowed in demo accounts. "
|
||||
"Demo sessions are automatically deleted after 30 minutes. "
|
||||
"Subscribe for full access.",
|
||||
"upgrade_url": "/pricing",
|
||||
"session_expires_at": session_info.get("expires_at")
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Session expired, invalid, or in failed/destroyed state
|
||||
logger.warning(f"Invalid demo session state", session_id=session_id, status=current_status)
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "session_expired",
|
||||
"message": "Tu sesión demo ha expirado. Crea una nueva sesión para continuar.",
|
||||
"message_en": "Your demo session has expired. Create a new session to continue.",
|
||||
"session_status": current_status
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Demo middleware error", error=str(e), session_id=session_id, path=request.url.path)
|
||||
# On error, return 401 instead of continuing
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "session_error",
|
||||
"message": "Error validando sesión demo. Por favor, inténtalo de nuevo.",
|
||||
"message_en": "Error validating demo session. Please try again."
|
||||
}
|
||||
)
|
||||
|
||||
# Check if this is a demo tenant (base template)
|
||||
elif tenant_id in DEMO_TENANT_IDS:
|
||||
# Direct access to demo tenant without session - block writes
|
||||
request.scope.setdefault("state", {})
|
||||
state = request.scope["state"]
|
||||
state["is_demo_session"] = True
|
||||
state["tenant_id"] = tenant_id
|
||||
|
||||
if request.method not in ["GET", "HEAD", "OPTIONS"]:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
"message": "Acceso directo al tenant demo no permitido. Crea una sesión demo.",
|
||||
"message_en": "Direct access to demo tenant not allowed. Create a demo session."
|
||||
}
|
||||
)
|
||||
|
||||
# Proceed with request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add demo session header to response if demo session
|
||||
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
|
||||
response.headers["X-Demo-Session"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[dict]:
|
||||
"""Get session information from demo service using JWT service token"""
|
||||
try:
|
||||
# Create JWT service token for gateway-to-demo-session communication
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning("Demo session fetch failed",
|
||||
session_id=session_id,
|
||||
status_code=response.status_code,
|
||||
response_text=response.text[:200] if hasattr(response, 'text') else '')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get session info", session_id=session_id, error=str(e))
|
||||
return None
|
||||
|
||||
async def _update_session_activity(self, session_id: str):
|
||||
"""Update session activity timestamp"""
|
||||
# Note: Activity tracking is handled by the demo service internally
|
||||
# No explicit endpoint needed - activity is updated on session access
|
||||
pass
|
||||
|
||||
def _check_blocked_path(self, path: str) -> Optional[dict]:
|
||||
"""Check if path is explicitly blocked for demo accounts"""
|
||||
for blocked_path in DEMO_BLOCKED_PATHS:
|
||||
if blocked_path in path:
|
||||
# Determine which category of blocked path
|
||||
if "forecast" in blocked_path:
|
||||
return DEMO_BLOCKED_PATH_MESSAGE["forecasts"]
|
||||
# Can add more categories here in the future
|
||||
return {
|
||||
"message": "Esta funcionalidad no está disponible para cuentas demo.",
|
||||
"message_en": "This functionality is not available for demo accounts."
|
||||
}
|
||||
return None
|
||||
|
||||
def _is_operation_allowed(self, method: str, path: str) -> bool:
|
||||
"""Check if method + path combination is allowed for demo"""
|
||||
|
||||
allowed_paths = DEMO_ALLOWED_OPERATIONS.get(method, [])
|
||||
|
||||
# Check for wildcard
|
||||
if "*" in allowed_paths:
|
||||
return True
|
||||
|
||||
# Check for exact match or pattern match
|
||||
for allowed_path in allowed_paths:
|
||||
if allowed_path.endswith("*"):
|
||||
# Pattern match: /api/orders/* matches /api/orders/123
|
||||
if path.startswith(allowed_path[:-1]):
|
||||
return True
|
||||
elif path == allowed_path:
|
||||
# Exact match
|
||||
return True
|
||||
|
||||
return False
|
||||
57
gateway/app/middleware/logging.py
Normal file
57
gateway/app/middleware/logging.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Logging middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Logging middleware class"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with logging"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"Request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"query_params": str(request.query_params),
|
||||
"client_host": request.client.host if request.client else "unknown",
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
f"Response: {response.status_code} in {duration:.3f}s",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"duration": duration,
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
93
gateway/app/middleware/rate_limit.py
Normal file
93
gateway/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Rate limiting middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware class"""
|
||||
|
||||
def __init__(self, app, calls_per_minute: int = 60):
|
||||
super().__init__(app)
|
||||
self.calls_per_minute = calls_per_minute
|
||||
self.requests: Dict[str, list] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with rate limiting"""
|
||||
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path in ["/health", "/metrics"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
|
||||
# Check rate limit
|
||||
if self._is_rate_limited(client_id):
|
||||
logger.warning(f"Rate limit exceeded for client: {client_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Record request
|
||||
self._record_request(client_id)
|
||||
|
||||
# Process request
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""Get client identifier"""
|
||||
# Try to get user ID from state (if authenticated)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
return f"ip:{request.client.host if request.client else 'unknown'}"
|
||||
|
||||
def _is_rate_limited(self, client_id: str) -> bool:
|
||||
"""Check if client is rate limited"""
|
||||
now = time.time()
|
||||
minute_ago = now - 60
|
||||
|
||||
# Get recent requests for this client
|
||||
if client_id not in self.requests:
|
||||
return False
|
||||
|
||||
# Filter requests from last minute
|
||||
recent_requests = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
# Update the list
|
||||
self.requests[client_id] = recent_requests
|
||||
|
||||
# Check if limit exceeded
|
||||
return len(recent_requests) >= self.calls_per_minute
|
||||
|
||||
def _record_request(self, client_id: str):
|
||||
"""Record a request for rate limiting"""
|
||||
now = time.time()
|
||||
|
||||
if client_id not in self.requests:
|
||||
self.requests[client_id] = []
|
||||
|
||||
self.requests[client_id].append(now)
|
||||
|
||||
# Keep only last minute of requests
|
||||
minute_ago = now - 60
|
||||
self.requests[client_id] = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
269
gateway/app/middleware/rate_limiting.py
Normal file
269
gateway/app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
API Rate Limiting Middleware for Gateway
|
||||
Enforces subscription-based API call quotas per hour
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import shared.redis_utils
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class APIRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce API rate limits based on subscription tier.
|
||||
|
||||
Quota limits per hour:
|
||||
- Starter: 100 calls/hour
|
||||
- Professional: 1,000 calls/hour
|
||||
- Enterprise: 10,000 calls/hour
|
||||
|
||||
Uses Redis to track API calls with hourly buckets.
|
||||
"""
|
||||
|
||||
def __init__(self, app, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Check API rate limit before processing request.
|
||||
"""
|
||||
# Skip rate limiting for certain paths
|
||||
if self._should_skip_rate_limit(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract tenant_id from request
|
||||
tenant_id = self._extract_tenant_id(request)
|
||||
|
||||
if not tenant_id:
|
||||
# No tenant ID - skip rate limiting for auth/public endpoints
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Get subscription tier from headers (added by AuthMiddleware)
|
||||
subscription_tier = request.headers.get("x-subscription-tier")
|
||||
|
||||
if not subscription_tier:
|
||||
# Fallback: get from request state if headers not available
|
||||
subscription_tier = getattr(request.state, "subscription_tier", None)
|
||||
|
||||
if not subscription_tier:
|
||||
# Final fallback: get from tenant service (should rarely happen)
|
||||
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
||||
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
|
||||
|
||||
# Get quota limit for tier
|
||||
quota_limit = self._get_quota_limit(subscription_tier)
|
||||
|
||||
# Check and increment quota
|
||||
allowed, current_count = await self._check_and_increment_quota(
|
||||
tenant_id,
|
||||
quota_limit
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"API rate limit exceeded",
|
||||
tenant_id=tenant_id,
|
||||
subscription_tier=subscription_tier,
|
||||
current_count=current_count,
|
||||
quota_limit=quota_limit,
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
|
||||
"current_count": current_count,
|
||||
"quota_limit": quota_limit,
|
||||
"reset_time": self._get_reset_time(),
|
||||
"upgrade_required": subscription_tier in ['starter', 'professional']
|
||||
}
|
||||
)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(quota_limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
|
||||
response.headers["X-RateLimit-Reset"] = self._get_reset_time()
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Rate limiting check failed, allowing request",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
path=request.url.path
|
||||
)
|
||||
# Fail open - allow request if rate limiting fails
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_rate_limit(self, path: str) -> bool:
|
||||
"""
|
||||
Determine if path should skip rate limiting.
|
||||
"""
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/",
|
||||
"/api/v1/plans", # Public pricing info
|
||||
]
|
||||
|
||||
for skip_path in skip_paths:
|
||||
if path.startswith(skip_path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_tenant_id(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract tenant ID from request headers or path.
|
||||
"""
|
||||
# Try header first
|
||||
tenant_id = request.headers.get("x-tenant-id")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Try to extract from path /api/v1/tenants/{tenant_id}/...
|
||||
path_parts = request.url.path.split("/")
|
||||
if "tenants" in path_parts:
|
||||
try:
|
||||
tenant_index = path_parts.index("tenants")
|
||||
if len(path_parts) > tenant_index + 1:
|
||||
return path_parts[tenant_index + 1]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
|
||||
"""
|
||||
Get subscription tier from tenant service (with caching).
|
||||
"""
|
||||
try:
|
||||
# Try to get from request state (if subscription middleware already ran)
|
||||
if hasattr(request.state, "subscription_tier"):
|
||||
return request.state.subscription_tier
|
||||
|
||||
# Call tenant service to get tier
|
||||
import httpx
|
||||
from gateway.app.core.config import settings
|
||||
|
||||
async with httpx.AsyncClient(timeout=2.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers={
|
||||
"x-service": "gateway"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("tier", "starter")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to get subscription tier, defaulting to starter",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
return "starter"
|
||||
|
||||
def _get_quota_limit(self, subscription_tier: str) -> int:
|
||||
"""
|
||||
Get API calls per hour quota for subscription tier.
|
||||
"""
|
||||
quota_map = {
|
||||
"starter": 100,
|
||||
"professional": 1000,
|
||||
"enterprise": 10000,
|
||||
"demo": 1000, # Same as professional
|
||||
}
|
||||
|
||||
return quota_map.get(subscription_tier.lower(), 100)
|
||||
|
||||
async def _check_and_increment_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_limit: int
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Check current quota usage and increment counter.
|
||||
|
||||
Returns:
|
||||
(allowed: bool, current_count: int)
|
||||
"""
|
||||
if not self.redis_client:
|
||||
# No Redis - fail open
|
||||
return True, 0
|
||||
|
||||
try:
|
||||
# Create hourly bucket key
|
||||
current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
|
||||
quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"
|
||||
|
||||
# Get current count
|
||||
current_count = await self.redis_client.get(quota_key)
|
||||
current_count = int(current_count) if current_count else 0
|
||||
|
||||
# Check if within limit
|
||||
if current_count >= quota_limit:
|
||||
return False, current_count
|
||||
|
||||
# Increment counter
|
||||
new_count = await self.redis_client.incr(quota_key)
|
||||
|
||||
# Set expiry (1 hour + 5 minutes buffer)
|
||||
await self.redis_client.expire(quota_key, 3900)
|
||||
|
||||
return True, new_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Redis quota check failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open
|
||||
return True, 0
|
||||
|
||||
def _get_reset_time(self) -> str:
|
||||
"""
|
||||
Get the reset time for the current hour bucket (top of next hour).
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
return next_hour.isoformat()
|
||||
|
||||
|
||||
async def get_rate_limit_middleware(app):
|
||||
"""
|
||||
Factory function to create rate limiting middleware with Redis client.
|
||||
"""
|
||||
try:
|
||||
from gateway.app.core.config import settings
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
|
||||
logger.info("API rate limiting middleware initialized with Redis")
|
||||
return APIRateLimitMiddleware(app, redis_client=redis_client)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize Redis for rate limiting, middleware will fail open",
|
||||
error=str(e)
|
||||
)
|
||||
return APIRateLimitMiddleware(app, redis_client=None)
|
||||
149
gateway/app/middleware/read_only_mode.py
Normal file
149
gateway/app/middleware/read_only_mode.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Gateway middleware to enforce read-only mode for subscriptions with status:
|
||||
- pending_cancellation (until cancellation_effective_date)
|
||||
- inactive (after cancellation or no active subscription)
|
||||
|
||||
Allowed operations in read-only mode:
|
||||
- GET requests (all read operations)
|
||||
- POST /api/v1/users/me/delete/request (account deletion)
|
||||
- POST /api/v1/subscriptions/reactivate (subscription reactivation)
|
||||
- POST /api/v1/subscriptions/* (subscription management)
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Whitelist of POST/PUT/DELETE endpoints allowed in read-only mode
|
||||
READ_ONLY_WHITELIST_PATTERNS = [
|
||||
r'^/api/v1/users/me/delete/request$',
|
||||
r'^/api/v1/users/me/export.*$',
|
||||
r'^/api/v1/tenants/.*/subscription/.*', # All tenant subscription endpoints
|
||||
r'^/api/v1/registration/.*', # Registration flow endpoints
|
||||
r'^/api/v1/auth/.*', # Allow auth operations
|
||||
r'^/api/v1/tenants/register$', # Allow new tenant registration (no existing tenant context)
|
||||
r'^/api/v1/tenants/.*/orchestrator/run-daily-workflow$', # Allow workflow testing
|
||||
r'^/api/v1/tenants/.*/inventory/ml/insights/.*', # Allow ML insights (safety stock optimization)
|
||||
r'^/api/v1/tenants/.*/production/ml/insights/.*', # Allow ML insights (yield prediction)
|
||||
r'^/api/v1/tenants/.*/procurement/ml/insights/.*', # Allow ML insights (supplier analysis, price forecasting)
|
||||
r'^/api/v1/tenants/.*/forecasting/ml/insights/.*', # Allow ML insights (rules generation)
|
||||
r'^/api/v1/tenants/.*/forecasting/operations/.*', # Allow forecasting operations
|
||||
r'^/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
|
||||
]
|
||||
|
||||
|
||||
class ReadOnlyModeMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce read-only mode based on subscription status
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str = "http://tenant-service:8000"):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url
|
||||
self.cache = {}
|
||||
self.cache_ttl = 60
|
||||
|
||||
async def check_subscription_status(self, tenant_id: str, authorization: str) -> dict:
|
||||
"""
|
||||
Check subscription status from tenant service
|
||||
Returns subscription data including status and read_only flag
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.tenant_service_url}/api/v1/tenants/{tenant_id}/subscription/status",
|
||||
headers={"Authorization": authorization}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
return {"status": "inactive", "is_read_only": True}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to check subscription status: {response.status_code}",
|
||||
extra={"tenant_id": tenant_id}
|
||||
)
|
||||
return {"status": "unknown", "is_read_only": False}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking subscription status: {e}",
|
||||
extra={"tenant_id": tenant_id}
|
||||
)
|
||||
return {"status": "unknown", "is_read_only": False}
|
||||
|
||||
def is_whitelisted_endpoint(self, path: str) -> bool:
|
||||
"""
|
||||
Check if endpoint is whitelisted for read-only mode
|
||||
"""
|
||||
for pattern in READ_ONLY_WHITELIST_PATTERNS:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_write_operation(self, method: str) -> bool:
|
||||
"""
|
||||
Determine if HTTP method is a write operation
|
||||
"""
|
||||
return method.upper() in ['POST', 'PUT', 'DELETE', 'PATCH']
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Process each request through read-only mode check
|
||||
"""
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
authorization = request.headers.get("Authorization")
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
if not tenant_id or not authorization:
|
||||
return await call_next(request)
|
||||
|
||||
if method.upper() == 'GET':
|
||||
return await call_next(request)
|
||||
|
||||
if self.is_whitelisted_endpoint(path):
|
||||
return await call_next(request)
|
||||
|
||||
if self.is_write_operation(method):
|
||||
subscription_data = await self.check_subscription_status(tenant_id, authorization)
|
||||
|
||||
if subscription_data.get("is_read_only", False):
|
||||
status_detail = subscription_data.get("status", "inactive")
|
||||
effective_date = subscription_data.get("cancellation_effective_date")
|
||||
|
||||
error_message = {
|
||||
"detail": "Account is in read-only mode",
|
||||
"reason": f"Subscription status: {status_detail}",
|
||||
"message": "Your subscription has been cancelled. You can view data but cannot make changes.",
|
||||
"action_required": "Reactivate your subscription to regain full access",
|
||||
"reactivation_url": "/app/settings/subscription"
|
||||
}
|
||||
|
||||
if effective_date:
|
||||
error_message["read_only_until"] = effective_date
|
||||
error_message["message"] = f"Your subscription is pending cancellation. Read-only mode starts on {effective_date}."
|
||||
|
||||
logger.info(
|
||||
"read_only_mode_enforced",
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"subscription_status": status_detail
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content=error_message
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
83
gateway/app/middleware/request_id.py
Normal file
83
gateway/app/middleware/request_id.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Request ID Middleware for distributed tracing
|
||||
Generates and propagates unique request IDs across all services
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to generate and propagate request IDs for distributed tracing.
|
||||
|
||||
Request IDs are:
|
||||
- Generated if not provided by client
|
||||
- Logged with every request
|
||||
- Propagated to all downstream services
|
||||
- Returned in response headers
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with request ID tracking"""
|
||||
|
||||
# Extract or generate request ID
|
||||
request_id = request.headers.get("X-Request-ID")
|
||||
if not request_id:
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Store in request state for access by routes
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Bind request ID to structured logger context
|
||||
logger_ctx = logger.bind(request_id=request_id)
|
||||
|
||||
# Inject request ID header for downstream services using HeaderManager
|
||||
# Note: This runs early in middleware chain, so we use add_header_for_middleware
|
||||
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
|
||||
|
||||
# Log request start
|
||||
logger_ctx.info(
|
||||
"Request started",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=request.client.host if request.client else None
|
||||
)
|
||||
|
||||
try:
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
# Log request completion
|
||||
logger_ctx.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Log request failure
|
||||
logger_ctx.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise
|
||||
462
gateway/app/middleware/subscription.py
Normal file
462
gateway/app/middleware/subscription.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Subscription Middleware - Enforces subscription limits and feature access
|
||||
Updated to support standardized URL structure with tier-based access control
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import structlog
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce subscription-based access control
|
||||
|
||||
Supports standardized URL structure:
|
||||
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
|
||||
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
|
||||
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
|
||||
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url.rstrip('/')
|
||||
self.redis_client = redis_client # Optional Redis client for abuse detection
|
||||
|
||||
# Define route patterns that require subscription validation
|
||||
# Using new standardized URL structure
|
||||
self.protected_routes = {
|
||||
# ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
|
||||
# Any service analytics endpoint
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
|
||||
'feature': 'analytics',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Analytics features (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# ===== TRAINING SERVICE - ALL TIERS =====
|
||||
r'^/api/v1/tenants/[^/]+/training/.*': {
|
||||
'feature': 'ml_training',
|
||||
'minimum_tier': 'basic',
|
||||
'allowed_tiers': ['basic', 'professional', 'enterprise'],
|
||||
'description': 'Machine learning model training (Available for all tiers)'
|
||||
},
|
||||
|
||||
# ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
|
||||
# Advanced reporting and exports
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
|
||||
'feature': 'advanced_exports',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Advanced export formats (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# Bulk operations
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
|
||||
'feature': 'bulk_operations',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Bulk operations (Professional/Enterprise only)'
|
||||
},
|
||||
}
|
||||
|
||||
# Routes that are explicitly allowed for all tiers (no check needed)
|
||||
self.public_tier_routes = [
|
||||
# Base CRUD operations - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
|
||||
|
||||
# Dashboard routes - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
|
||||
|
||||
# Operations routes - ALL TIERS (role-based control applies)
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process the request and check subscription requirements"""
|
||||
|
||||
# Skip subscription check for certain routes
|
||||
if self._should_skip_subscription_check(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route is explicitly allowed for all tiers
|
||||
if self._is_public_tier_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route requires subscription validation
|
||||
subscription_requirement = self._get_subscription_requirement(request.url.path)
|
||||
if not subscription_requirement:
|
||||
return await call_next(request)
|
||||
|
||||
# Get tenant ID from request
|
||||
tenant_id = self._extract_tenant_id(request)
|
||||
if not tenant_id:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": "subscription_validation_failed",
|
||||
"message": "Tenant ID required for subscription validation",
|
||||
"code": "MISSING_TENANT_ID"
|
||||
}
|
||||
)
|
||||
|
||||
# Validate subscription with new tier-based system
|
||||
validation_result = await self._validate_subscription_tier(
|
||||
request,
|
||||
tenant_id,
|
||||
subscription_requirement.get('feature'),
|
||||
subscription_requirement.get('minimum_tier'),
|
||||
subscription_requirement.get('allowed_tiers', [])
|
||||
)
|
||||
|
||||
if not validation_result['allowed']:
|
||||
# Use enhanced error response with conversion optimization
|
||||
feature = subscription_requirement.get('feature')
|
||||
current_tier = validation_result.get('current_tier', 'unknown')
|
||||
required_tier = subscription_requirement.get('minimum_tier')
|
||||
allowed_tiers = subscription_requirement.get('allowed_tiers', [])
|
||||
|
||||
# Create conversion-optimized error response
|
||||
enhanced_response = create_upgrade_required_response(
|
||||
feature=feature,
|
||||
current_tier=current_tier,
|
||||
required_tier=required_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
custom_message=validation_result.get('message')
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=enhanced_response.status_code,
|
||||
content=enhanced_response.dict()
|
||||
)
|
||||
|
||||
# Subscription validation passed, continue with request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _is_public_tier_route(self, path: str) -> bool:
|
||||
"""
|
||||
Check if route is explicitly allowed for all subscription tiers
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if route is allowed for all tiers
|
||||
"""
|
||||
for pattern in self.public_tier_routes:
|
||||
if re.match(pattern, path):
|
||||
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_skip_subscription_check(self, request: Request) -> bool:
|
||||
"""Check if subscription validation should be skipped"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# Skip for health checks, auth, and public routes
|
||||
skip_patterns = [
|
||||
r'/health.*',
|
||||
r'/metrics.*',
|
||||
r'/api/v1/auth/.*',
|
||||
r'/api/v1/tenants/[^/]+/subscription/.*', # All tenant subscription endpoints
|
||||
r'/api/v1/registration/.*', # Registration flow endpoints
|
||||
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
||||
r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
|
||||
r'/docs.*',
|
||||
r'/openapi\.json',
|
||||
# Training monitoring endpoints (WebSocket and status checks)
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
|
||||
]
|
||||
|
||||
# Skip OPTIONS requests (CORS preflight)
|
||||
if method == "OPTIONS":
|
||||
return True
|
||||
|
||||
for pattern in skip_patterns:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
|
||||
"""Get subscription requirement for a given path"""
|
||||
for pattern, requirement in self.protected_routes.items():
|
||||
if re.match(pattern, path):
|
||||
return requirement
|
||||
return None
|
||||
|
||||
def _extract_tenant_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract tenant ID from request"""
|
||||
# Try to get from URL path first
|
||||
path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
|
||||
if path_match:
|
||||
return path_match.group(1)
|
||||
|
||||
# Try to get from headers
|
||||
tenant_id = request.headers.get('x-tenant-id')
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Try to get from user state (set by auth middleware)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return request.state.user.get('tenant_id')
|
||||
|
||||
return None
|
||||
|
||||
async def _validate_subscription_tier(
|
||||
self,
|
||||
request: Request,
|
||||
tenant_id: str,
|
||||
feature: Optional[str],
|
||||
minimum_tier: str,
|
||||
allowed_tiers: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate subscription tier access using cached subscription lookup
|
||||
|
||||
Args:
|
||||
request: FastAPI request
|
||||
tenant_id: Tenant ID
|
||||
feature: Feature name (optional, for additional checks)
|
||||
minimum_tier: Minimum required subscription tier
|
||||
allowed_tiers: List of allowed subscription tiers
|
||||
|
||||
Returns:
|
||||
Dict with 'allowed' boolean and additional metadata
|
||||
"""
|
||||
try:
|
||||
# Check if JWT already has subscription
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id', 'unknown')
|
||||
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
# Use JWT data directly - NO HTTP CALL!
|
||||
current_tier = user_context.get("subscription_tier", "starter")
|
||||
|
||||
logger.debug("Using subscription tier from JWT (no HTTP call)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers)
|
||||
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (JWT subscription)',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Use unified HeaderManager for consistent header handling
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Extract user_id for logging (fallback path)
|
||||
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
|
||||
|
||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=1.0, # Connection timeout - very short for cached endpoint
|
||||
read=5.0, # Read timeout - short for cached lookup
|
||||
write=1.0, # Write timeout
|
||||
pool=1.0 # Pool timeout
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
# Use fast cached tier endpoint (new URL pattern)
|
||||
tier_response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if tier_response.status_code != 200:
|
||||
logger.warning(
|
||||
"Failed to get subscription tier from cache",
|
||||
tenant_id=tenant_id,
|
||||
status_code=tier_response.status_code,
|
||||
response_text=tier_response.text
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_tier': 'unknown'
|
||||
}
|
||||
|
||||
tier_data = tier_response.json()
|
||||
current_tier = tier_data.get('tier', 'starter').lower()
|
||||
|
||||
logger.debug("Subscription tier check (cached)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
cached=tier_data.get('cached', False))
|
||||
|
||||
# Check if current tier is in allowed tiers
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
False,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Tier check passed
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"database"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Timeout validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation timeout)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.error(
|
||||
"Request error validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Subscription validation error",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation error)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
|
||||
async def _log_subscription_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
requested_feature: str,
|
||||
current_tier: str,
|
||||
access_granted: bool,
|
||||
source: str # "jwt" or "database"
|
||||
):
|
||||
"""
|
||||
Log all subscription-gated access attempts for audit and anomaly detection.
|
||||
"""
|
||||
logger.info(
|
||||
"Subscription access check",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=requested_feature,
|
||||
tier=current_tier,
|
||||
granted=access_granted,
|
||||
source=source,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
# For denied access, check for suspicious patterns
|
||||
if not access_granted:
|
||||
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
|
||||
|
||||
async def _check_for_abuse_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
feature: str
|
||||
):
|
||||
"""
|
||||
Detect potential abuse patterns like repeated premium feature access attempts.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
# Track denied attempts in a sliding window
|
||||
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
|
||||
|
||||
try:
|
||||
attempts = await self.redis_client.incr(key)
|
||||
if attempts == 1:
|
||||
await self.redis_client.expire(key, 3600) # 1 hour window
|
||||
|
||||
# Alert if too many denied attempts (potential bypass attempt)
|
||||
if attempts > 10:
|
||||
logger.warning(
|
||||
"SECURITY: Excessive premium feature access attempts detected",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=feature,
|
||||
attempts=attempts,
|
||||
window="1 hour"
|
||||
)
|
||||
# Could trigger alert to security team here
|
||||
except Exception as e:
|
||||
logger.warning("Failed to track abuse patterns", error=str(e))
|
||||
|
||||
Reference in New Issue
Block a user