Files
bakery-ia/gateway/app/middleware/subscription.py
Urtzi Alfaro bf1db7cb9e New token arch
2026-01-10 21:45:37 +01:00

467 lines
18 KiB
Python

"""
Subscription Middleware - Enforces subscription limits and feature access
Updated to support standardized URL structure with tier-based access control
"""
import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.utils.subscription_error_responses import create_upgrade_required_response
# Module-level structlog logger shared by SubscriptionMiddleware.
logger = structlog.get_logger()
class SubscriptionMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce subscription-based access control.

    Supports standardized URL structure:
    - Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
    - Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
    - Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics[/*]): PROFESSIONAL+
    - Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)

    Tier validation fails open: when the tier cannot be determined (tenant
    service error, timeout, unexpected exception) the request is allowed and
    the downstream service is left to perform its own detailed checks.
    """

    # Incoming headers that describe the client's request body or connection.
    # They are NOT copied onto the body-less GET sent to the tenant service:
    # e.g. a forwarded Content-Length from a POST declares a body that is never
    # sent, which breaks (or stalls) the upstream lookup.
    _NON_FORWARDED_HEADERS = (
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        "keep-alive",
        "upgrade",
        "expect",
    )

    def __init__(self, app, tenant_service_url: str, redis_client=None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            tenant_service_url: Base URL of the tenant service used for the
                fallback tier lookup (trailing slashes are stripped).
            redis_client: Optional async Redis client, used only to count
                denied attempts for abuse detection.
        """
        super().__init__(app)
        self.tenant_service_url = tenant_service_url.rstrip('/')
        self.redis_client = redis_client  # Optional Redis client for abuse detection

        # Route patterns that require subscription validation, using the new
        # standardized URL structure. Checked with re.match in insertion order;
        # the first matching pattern wins.
        self.protected_routes = {
            # ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
            # Any service analytics endpoint. "(?:/|$)" also covers the bare
            # ".../{service}/analytics" resource, which the public routes below
            # exclude and which "analytics/.*" used to let through unchecked.
            r'^/api/v1/tenants/[^/]+/[^/]+/analytics(?:/|$)': {
                'feature': 'analytics',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Analytics features (Professional/Enterprise only)'
            },
            # ===== TRAINING SERVICE - ALL TIERS =====
            r'^/api/v1/tenants/[^/]+/training/.*': {
                'feature': 'ml_training',
                'minimum_tier': 'basic',
                'allowed_tiers': ['basic', 'professional', 'enterprise'],
                'description': 'Machine learning model training (Available for all tiers)'
            },
            # ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
            # Advanced reporting and exports
            r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
                'feature': 'advanced_exports',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Advanced export formats (Professional/Enterprise only)'
            },
            # Bulk operations (including the bare ".../{service}/bulk" resource)
            r'^/api/v1/tenants/[^/]+/[^/]+/bulk(?:/|$)': {
                'feature': 'bulk_operations',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Bulk operations (Professional/Enterprise only)'
            },
        }

        # Routes that are explicitly allowed for all tiers (no check needed)
        self.public_tier_routes = [
            # Base CRUD operations - ALL TIERS. The negative lookahead keeps
            # analytics/advanced-export/bulk resources out of this list.
            r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
            r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
            # Dashboard routes - ALL TIERS
            r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
            # Operations routes - ALL TIERS (role-based control applies)
            r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
        ]

    async def dispatch(self, request: Request, call_next):
        """
        Check subscription requirements, then forward the request.

        Returns a 400 JSON response when a protected route has no resolvable
        tenant ID, the upgrade-required response built by
        ``create_upgrade_required_response`` when the tenant's tier is not
        allowed, and otherwise the downstream response.
        """
        path = request.url.path

        # Skipped routes and routes explicitly open to every tier pass through.
        if self._should_skip_subscription_check(request) or self._is_public_tier_route(path):
            return await call_next(request)

        # Routes without a subscription requirement pass through as well.
        subscription_requirement = self._get_subscription_requirement(path)
        if not subscription_requirement:
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            return JSONResponse(
                status_code=400,
                content={
                    "error": "subscription_validation_failed",
                    "message": "Tenant ID required for subscription validation",
                    "code": "MISSING_TENANT_ID"
                }
            )

        feature = subscription_requirement.get('feature')
        required_tier = subscription_requirement.get('minimum_tier')
        allowed_tiers = subscription_requirement.get('allowed_tiers', [])

        validation_result = await self._validate_subscription_tier(
            request,
            tenant_id,
            feature,
            required_tier,
            allowed_tiers
        )

        if not validation_result['allowed']:
            # Conversion-optimized "upgrade required" error response
            enhanced_response = create_upgrade_required_response(
                feature=feature,
                current_tier=validation_result.get('current_tier', 'unknown'),
                required_tier=required_tier,
                allowed_tiers=allowed_tiers,
                custom_message=validation_result.get('message')
            )
            return JSONResponse(
                status_code=enhanced_response.status_code,
                content=enhanced_response.dict()
            )

        # Subscription validation passed, continue with request
        return await call_next(request)

    def _is_public_tier_route(self, path: str) -> bool:
        """
        Check if route is explicitly allowed for all subscription tiers.

        Args:
            path: Request path

        Returns:
            True if the path matches one of ``self.public_tier_routes``
        """
        for pattern in self.public_tier_routes:
            if re.match(pattern, path):
                logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
                return True
        return False

    def _should_skip_subscription_check(self, request: Request) -> bool:
        """
        Return True for requests that never need subscription validation.

        CORS preflight (OPTIONS) requests are skipped, as are paths that
        start with one of the prefixes below. The patterns are applied with
        re.match and are not end-anchored, so they behave as prefix matches.
        """
        # Skip OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return True

        # Skip for health checks, auth, and public routes
        skip_patterns = [
            r'/health.*',
            r'/metrics.*',
            r'/api/v1/auth/.*',
            r'/api/v1/subscriptions/.*',  # Subscription management itself
            r'/api/v1/tenants/[^/]+/members.*',  # Basic tenant info
            r'/docs.*',
            r'/openapi\.json'
        ]
        path = request.url.path
        return any(re.match(pattern, path) for pattern in skip_patterns)

    def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, Any]]:
        """Return the requirement of the first protected pattern matching *path*, or None."""
        for pattern, requirement in self.protected_routes.items():
            if re.match(pattern, path):
                return requirement
        return None

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract the tenant ID from the URL path, the ``x-tenant-id`` header,
        or the user context set on ``request.state.user``, in that order.
        """
        path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
        if path_match:
            return path_match.group(1)

        tenant_id = request.headers.get('x-tenant-id')
        if tenant_id:
            return tenant_id

        # User state is set by the auth middleware
        user = getattr(request.state, 'user', None)
        if user:
            return user.get('tenant_id')
        return None

    @staticmethod
    def _normalize_tier(tier: Any) -> str:
        """Lower-case a tier name, defaulting to 'starter' when it is missing or empty."""
        return str(tier or 'starter').lower()

    @staticmethod
    def _jwt_subscription_applies(user: Dict[str, Any], tenant_id: str) -> bool:
        """
        Return True when the JWT's subscription claims may be used for *tenant_id*.

        The claims are used only when the auth middleware flagged them
        (``subscription_from_jwt``) and the token's ``tenant_id`` (if any)
        matches the tenant being accessed.
        """
        if not user.get("subscription_from_jwt"):
            return False
        # NOTE(review): assumes the JWT subscription claims describe the token's
        # own tenant_id; for any other tenant in the URL, ask the tenant service.
        token_tenant = user.get("tenant_id")
        return not token_tenant or str(token_tenant) == str(tenant_id)

    def _build_tenant_service_headers(
        self,
        request: Request,
        user: Optional[Dict[str, Any]]
    ) -> Dict[str, str]:
        """
        Build headers for the tenant-service tier lookup.

        Mirrors the gateway's forwarding pattern (incoming headers plus user
        context headers) but drops ``_NON_FORWARDED_HEADERS``, because the
        lookup is a body-less GET on a separate connection.
        """
        headers = dict(request.headers)
        for name in self._NON_FORWARDED_HEADERS:
            headers.pop(name, None)

        # Add user context headers if available
        if user:
            headers["x-user-id"] = str(user.get('user_id', 'unknown'))
            headers["x-user-email"] = str(user.get('email', ''))
            headers["x-user-role"] = str(user.get('role', 'user'))
            headers["x-user-full-name"] = str(user.get('full_name', ''))
            headers["x-tenant-id"] = str(user.get('tenant_id', ''))
        return headers

    async def _tier_access_result(
        self,
        tenant_id: str,
        user_id: str,
        feature: Optional[str],
        current_tier: str,
        allowed_tiers: List[str],
        source: str,
        granted_message: str
    ) -> Dict[str, Any]:
        """
        Compare *current_tier* with *allowed_tiers*, audit the decision, and
        build the validation result dict.

        Both grants and denials are passed to ``_log_subscription_access`` so
        that denied attempts feed abuse detection regardless of *source*.
        """
        granted = current_tier in [tier.lower() for tier in allowed_tiers]
        await self._log_subscription_access(
            tenant_id,
            user_id,
            feature,
            current_tier,
            granted,
            source
        )
        if not granted:
            tier_names = ', '.join(allowed_tiers)
            return {
                'allowed': False,
                'message': f'This feature requires a {tier_names} subscription plan',
                'current_tier': current_tier
            }
        return {
            'allowed': True,
            'message': granted_message,
            'current_tier': current_tier
        }

    async def _validate_subscription_tier(
        self,
        request: Request,
        tenant_id: str,
        feature: Optional[str],
        minimum_tier: str,
        allowed_tiers: List[str]
    ) -> Dict[str, Any]:
        """
        Validate subscription tier access, preferring the tier carried in the JWT.

        The tier comes from ``request.state.user`` when the auth middleware
        marked it as JWT-provided and the token's tenant matches *tenant_id*;
        otherwise it is fetched from the tenant service's cached tier endpoint
        (fallback for old tokens).

        Args:
            request: FastAPI request
            tenant_id: Tenant ID being accessed
            feature: Feature name (optional, used for audit logging)
            minimum_tier: Minimum required subscription tier (logged only)
            allowed_tiers: List of allowed subscription tiers

        Returns:
            Dict with 'allowed' (bool), 'message' (str) and 'current_tier'
            (str; 'unknown' when the check failed open).
        """
        try:
            user = getattr(request.state, 'user', None)

            if user and self._jwt_subscription_applies(user, tenant_id):
                # Use JWT data directly - NO HTTP CALL!
                current_tier = self._normalize_tier(user.get("subscription_tier"))
                logger.debug("Using subscription tier from JWT (no HTTP call)",
                             tenant_id=tenant_id,
                             current_tier=current_tier,
                             minimum_tier=minimum_tier,
                             allowed_tiers=allowed_tiers)
                return await self._tier_access_result(
                    tenant_id,
                    str(user.get('user_id', 'unknown')),
                    feature,
                    current_tier,
                    allowed_tiers,
                    source="jwt",
                    granted_message='Access granted (JWT subscription)'
                )

            # Fallback: ask the tenant service's fast cached tier endpoint.
            user_id = str(user.get('user_id', 'unknown')) if user else 'unknown'
            headers = self._build_tenant_service_headers(request, user)
            timeout_config = httpx.Timeout(
                connect=1.0,  # Connection timeout - very short for cached endpoint
                read=5.0,     # Read timeout - short for cached lookup
                write=1.0,    # Write timeout
                pool=1.0      # Pool timeout
            )
            async with httpx.AsyncClient(timeout=timeout_config) as client:
                # Use the URL this middleware was configured with
                tier_response = await client.get(
                    f"{self.tenant_service_url}/api/v1/subscriptions/{tenant_id}/tier",
                    headers=headers
                )

                if tier_response.status_code != 200:
                    logger.warning(
                        "Failed to get subscription tier from cache",
                        tenant_id=tenant_id,
                        status_code=tier_response.status_code,
                        response_text=tier_response.text
                    )
                    # Fail open for availability
                    return {
                        'allowed': True,
                        'message': 'Access granted (validation service unavailable)',
                        'current_tier': 'unknown'
                    }

                tier_data = tier_response.json()

            current_tier = self._normalize_tier(tier_data.get('tier'))
            logger.debug("Subscription tier check (cached)",
                         tenant_id=tenant_id,
                         current_tier=current_tier,
                         minimum_tier=minimum_tier,
                         allowed_tiers=allowed_tiers,
                         cached=tier_data.get('cached', False))
            return await self._tier_access_result(
                tenant_id,
                user_id,
                feature,
                current_tier,
                allowed_tiers,
                source="database",
                granted_message='Access granted'
            )

        except (asyncio.TimeoutError, httpx.TimeoutException):
            # httpx signals timeouts with httpx.TimeoutException (a RequestError
            # subclass), so it must be caught before the RequestError handler.
            logger.error(
                "Timeout validating subscription",
                tenant_id=tenant_id,
                feature=feature
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation timeout)',
                'current_tier': 'unknown'
            }
        except httpx.RequestError as e:
            logger.error(
                "Request error validating subscription",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability
            return {
                'allowed': True,
                'message': 'Access granted (validation service unavailable)',
                'current_tier': 'unknown'
            }
        except Exception as e:
            logger.error(
                "Subscription validation error",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e),
                exc_info=True
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation error)',
                'current_tier': 'unknown'
            }

    async def _log_subscription_access(
        self,
        tenant_id: str,
        user_id: str,
        requested_feature: Optional[str],
        current_tier: str,
        access_granted: bool,
        source: str  # "jwt" or "database"
    ):
        """
        Log a subscription-gated access decision for audit, and run abuse
        detection for denied attempts.
        """
        logger.info(
            "Subscription access check",
            tenant_id=tenant_id,
            user_id=user_id,
            feature=requested_feature,
            tier=current_tier,
            granted=access_granted,
            source=source,
            timestamp=datetime.now(timezone.utc).isoformat()
        )

        # For denied access, check for suspicious patterns
        if not access_granted:
            await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)

    async def _check_for_abuse_patterns(
        self,
        tenant_id: str,
        user_id: str,
        feature: Optional[str]
    ):
        """
        Count denied attempts per tenant/user/feature in Redis and warn when
        more than 10 occur within the key's 1-hour expiry window.

        No-op without a Redis client. Redis failures are logged, not raised.
        """
        if not self.redis_client:
            return

        key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
        try:
            attempts = await self.redis_client.incr(key)
            if attempts == 1:
                # Start the window on the first denial
                await self.redis_client.expire(key, 3600)  # 1 hour window

            # Alert if too many denied attempts (potential bypass attempt)
            if attempts > 10:
                logger.warning(
                    "SECURITY: Excessive premium feature access attempts detected",
                    tenant_id=tenant_id,
                    user_id=user_id,
                    feature=feature,
                    attempts=attempts,
                    window="1 hour"
                )
                # Could trigger alert to security team here
        except Exception as e:
            logger.warning("Failed to track abuse patterns", error=str(e))