New token arch
This commit is contained in:
@@ -12,6 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.config import settings
|
||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||
@@ -30,9 +31,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str):
|
||||
def __init__(self, app, tenant_service_url: str, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url.rstrip('/')
|
||||
self.redis_client = redis_client # Optional Redis client for abuse detection
|
||||
|
||||
# Define route patterns that require subscription validation
|
||||
# Using new standardized URL structure
|
||||
@@ -236,20 +238,60 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
Dict with 'allowed' boolean and additional metadata
|
||||
"""
|
||||
try:
|
||||
# Use the same authentication pattern as gateway routes
|
||||
# Check if JWT already has subscription
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id', 'unknown')
|
||||
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
# Use JWT data directly - NO HTTP CALL!
|
||||
current_tier = user_context.get("subscription_tier", "starter")
|
||||
|
||||
logger.debug("Using subscription tier from JWT (no HTTP call)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers)
|
||||
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (JWT subscription)',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Use the same authentication pattern as gateway routes for fallback
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Extract user_id for logging (fallback path)
|
||||
user_id = 'unknown'
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
headers["x-user-id"] = str(user.get('user_id', ''))
|
||||
user_id = str(user.get('user_id', 'unknown'))
|
||||
headers["x-user-id"] = user_id
|
||||
headers["x-user-email"] = str(user.get('email', ''))
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
|
||||
|
||||
# Call tenant service fast tier endpoint with caching
|
||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=1.0, # Connection timeout - very short for cached endpoint
|
||||
read=5.0, # Read timeout - short for cached lookup
|
||||
@@ -291,6 +333,15 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
# Check if current tier is in allowed tiers
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
False,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
@@ -298,6 +349,15 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
}
|
||||
|
||||
# Tier check passed
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"database"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted',
|
||||
@@ -343,3 +403,64 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
|
||||
async def _log_subscription_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
requested_feature: str,
|
||||
current_tier: str,
|
||||
access_granted: bool,
|
||||
source: str # "jwt" or "database"
|
||||
):
|
||||
"""
|
||||
Log all subscription-gated access attempts for audit and anomaly detection.
|
||||
"""
|
||||
logger.info(
|
||||
"Subscription access check",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=requested_feature,
|
||||
tier=current_tier,
|
||||
granted=access_granted,
|
||||
source=source,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
# For denied access, check for suspicious patterns
|
||||
if not access_granted:
|
||||
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
|
||||
|
||||
async def _check_for_abuse_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
feature: str
|
||||
):
|
||||
"""
|
||||
Detect potential abuse patterns like repeated premium feature access attempts.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
# Track denied attempts in a sliding window
|
||||
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
|
||||
|
||||
try:
|
||||
attempts = await self.redis_client.incr(key)
|
||||
if attempts == 1:
|
||||
await self.redis_client.expire(key, 3600) # 1 hour window
|
||||
|
||||
# Alert if too many denied attempts (potential bypass attempt)
|
||||
if attempts > 10:
|
||||
logger.warning(
|
||||
"SECURITY: Excessive premium feature access attempts detected",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=feature,
|
||||
attempts=attempts,
|
||||
window="1 hour"
|
||||
)
|
||||
# Could trigger alert to security team here
|
||||
except Exception as e:
|
||||
logger.warning("Failed to track abuse patterns", error=str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user