Initial commit - production deployment
This commit is contained in:
462
gateway/app/middleware/subscription.py
Normal file
462
gateway/app/middleware/subscription.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Subscription Middleware - Enforces subscription limits and feature access
|
||||
Updated to support standardized URL structure with tier-based access control
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import structlog
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce subscription-based access control
|
||||
|
||||
Supports standardized URL structure:
|
||||
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
|
||||
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
|
||||
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
|
||||
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url.rstrip('/')
|
||||
self.redis_client = redis_client # Optional Redis client for abuse detection
|
||||
|
||||
# Define route patterns that require subscription validation
|
||||
# Using new standardized URL structure
|
||||
self.protected_routes = {
|
||||
# ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
|
||||
# Any service analytics endpoint
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
|
||||
'feature': 'analytics',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Analytics features (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# ===== TRAINING SERVICE - ALL TIERS =====
|
||||
r'^/api/v1/tenants/[^/]+/training/.*': {
|
||||
'feature': 'ml_training',
|
||||
'minimum_tier': 'basic',
|
||||
'allowed_tiers': ['basic', 'professional', 'enterprise'],
|
||||
'description': 'Machine learning model training (Available for all tiers)'
|
||||
},
|
||||
|
||||
# ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
|
||||
# Advanced reporting and exports
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
|
||||
'feature': 'advanced_exports',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Advanced export formats (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# Bulk operations
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
|
||||
'feature': 'bulk_operations',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Bulk operations (Professional/Enterprise only)'
|
||||
},
|
||||
}
|
||||
|
||||
# Routes that are explicitly allowed for all tiers (no check needed)
|
||||
self.public_tier_routes = [
|
||||
# Base CRUD operations - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
|
||||
|
||||
# Dashboard routes - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
|
||||
|
||||
# Operations routes - ALL TIERS (role-based control applies)
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process the request and check subscription requirements"""
|
||||
|
||||
# Skip subscription check for certain routes
|
||||
if self._should_skip_subscription_check(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route is explicitly allowed for all tiers
|
||||
if self._is_public_tier_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route requires subscription validation
|
||||
subscription_requirement = self._get_subscription_requirement(request.url.path)
|
||||
if not subscription_requirement:
|
||||
return await call_next(request)
|
||||
|
||||
# Get tenant ID from request
|
||||
tenant_id = self._extract_tenant_id(request)
|
||||
if not tenant_id:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": "subscription_validation_failed",
|
||||
"message": "Tenant ID required for subscription validation",
|
||||
"code": "MISSING_TENANT_ID"
|
||||
}
|
||||
)
|
||||
|
||||
# Validate subscription with new tier-based system
|
||||
validation_result = await self._validate_subscription_tier(
|
||||
request,
|
||||
tenant_id,
|
||||
subscription_requirement.get('feature'),
|
||||
subscription_requirement.get('minimum_tier'),
|
||||
subscription_requirement.get('allowed_tiers', [])
|
||||
)
|
||||
|
||||
if not validation_result['allowed']:
|
||||
# Use enhanced error response with conversion optimization
|
||||
feature = subscription_requirement.get('feature')
|
||||
current_tier = validation_result.get('current_tier', 'unknown')
|
||||
required_tier = subscription_requirement.get('minimum_tier')
|
||||
allowed_tiers = subscription_requirement.get('allowed_tiers', [])
|
||||
|
||||
# Create conversion-optimized error response
|
||||
enhanced_response = create_upgrade_required_response(
|
||||
feature=feature,
|
||||
current_tier=current_tier,
|
||||
required_tier=required_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
custom_message=validation_result.get('message')
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=enhanced_response.status_code,
|
||||
content=enhanced_response.dict()
|
||||
)
|
||||
|
||||
# Subscription validation passed, continue with request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _is_public_tier_route(self, path: str) -> bool:
|
||||
"""
|
||||
Check if route is explicitly allowed for all subscription tiers
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if route is allowed for all tiers
|
||||
"""
|
||||
for pattern in self.public_tier_routes:
|
||||
if re.match(pattern, path):
|
||||
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_skip_subscription_check(self, request: Request) -> bool:
|
||||
"""Check if subscription validation should be skipped"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# Skip for health checks, auth, and public routes
|
||||
skip_patterns = [
|
||||
r'/health.*',
|
||||
r'/metrics.*',
|
||||
r'/api/v1/auth/.*',
|
||||
r'/api/v1/tenants/[^/]+/subscription/.*', # All tenant subscription endpoints
|
||||
r'/api/v1/registration/.*', # Registration flow endpoints
|
||||
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
||||
r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
|
||||
r'/docs.*',
|
||||
r'/openapi\.json',
|
||||
# Training monitoring endpoints (WebSocket and status checks)
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
|
||||
]
|
||||
|
||||
# Skip OPTIONS requests (CORS preflight)
|
||||
if method == "OPTIONS":
|
||||
return True
|
||||
|
||||
for pattern in skip_patterns:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
|
||||
"""Get subscription requirement for a given path"""
|
||||
for pattern, requirement in self.protected_routes.items():
|
||||
if re.match(pattern, path):
|
||||
return requirement
|
||||
return None
|
||||
|
||||
def _extract_tenant_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract tenant ID from request"""
|
||||
# Try to get from URL path first
|
||||
path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
|
||||
if path_match:
|
||||
return path_match.group(1)
|
||||
|
||||
# Try to get from headers
|
||||
tenant_id = request.headers.get('x-tenant-id')
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Try to get from user state (set by auth middleware)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return request.state.user.get('tenant_id')
|
||||
|
||||
return None
|
||||
|
||||
async def _validate_subscription_tier(
|
||||
self,
|
||||
request: Request,
|
||||
tenant_id: str,
|
||||
feature: Optional[str],
|
||||
minimum_tier: str,
|
||||
allowed_tiers: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate subscription tier access using cached subscription lookup
|
||||
|
||||
Args:
|
||||
request: FastAPI request
|
||||
tenant_id: Tenant ID
|
||||
feature: Feature name (optional, for additional checks)
|
||||
minimum_tier: Minimum required subscription tier
|
||||
allowed_tiers: List of allowed subscription tiers
|
||||
|
||||
Returns:
|
||||
Dict with 'allowed' boolean and additional metadata
|
||||
"""
|
||||
try:
|
||||
# Check if JWT already has subscription
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id', 'unknown')
|
||||
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
# Use JWT data directly - NO HTTP CALL!
|
||||
current_tier = user_context.get("subscription_tier", "starter")
|
||||
|
||||
logger.debug("Using subscription tier from JWT (no HTTP call)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers)
|
||||
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (JWT subscription)',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Use unified HeaderManager for consistent header handling
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Extract user_id for logging (fallback path)
|
||||
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
|
||||
|
||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=1.0, # Connection timeout - very short for cached endpoint
|
||||
read=5.0, # Read timeout - short for cached lookup
|
||||
write=1.0, # Write timeout
|
||||
pool=1.0 # Pool timeout
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
# Use fast cached tier endpoint (new URL pattern)
|
||||
tier_response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if tier_response.status_code != 200:
|
||||
logger.warning(
|
||||
"Failed to get subscription tier from cache",
|
||||
tenant_id=tenant_id,
|
||||
status_code=tier_response.status_code,
|
||||
response_text=tier_response.text
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_tier': 'unknown'
|
||||
}
|
||||
|
||||
tier_data = tier_response.json()
|
||||
current_tier = tier_data.get('tier', 'starter').lower()
|
||||
|
||||
logger.debug("Subscription tier check (cached)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
cached=tier_data.get('cached', False))
|
||||
|
||||
# Check if current tier is in allowed tiers
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
False,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Tier check passed
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"database"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Timeout validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation timeout)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.error(
|
||||
"Request error validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Subscription validation error",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation error)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
|
||||
async def _log_subscription_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
requested_feature: str,
|
||||
current_tier: str,
|
||||
access_granted: bool,
|
||||
source: str # "jwt" or "database"
|
||||
):
|
||||
"""
|
||||
Log all subscription-gated access attempts for audit and anomaly detection.
|
||||
"""
|
||||
logger.info(
|
||||
"Subscription access check",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=requested_feature,
|
||||
tier=current_tier,
|
||||
granted=access_granted,
|
||||
source=source,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
# For denied access, check for suspicious patterns
|
||||
if not access_granted:
|
||||
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
|
||||
|
||||
async def _check_for_abuse_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
feature: str
|
||||
):
|
||||
"""
|
||||
Detect potential abuse patterns like repeated premium feature access attempts.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
# Track denied attempts in a sliding window
|
||||
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
|
||||
|
||||
try:
|
||||
attempts = await self.redis_client.incr(key)
|
||||
if attempts == 1:
|
||||
await self.redis_client.expire(key, 3600) # 1 hour window
|
||||
|
||||
# Alert if too many denied attempts (potential bypass attempt)
|
||||
if attempts > 10:
|
||||
logger.warning(
|
||||
"SECURITY: Excessive premium feature access attempts detected",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=feature,
|
||||
attempts=attempts,
|
||||
window="1 hour"
|
||||
)
|
||||
# Could trigger alert to security team here
|
||||
except Exception as e:
|
||||
logger.warning("Failed to track abuse patterns", error=str(e))
|
||||
|
||||
Reference in New Issue
Block a user