Files
bakery-ia/gateway/app/middleware/subscription.py
2025-09-21 15:51:58 +02:00

298 lines
12 KiB
Python

"""
Subscription Middleware - Enforces subscription limits and feature access
"""
import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional
import asyncio
from app.core.config import settings
logger = structlog.get_logger()
class SubscriptionMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce subscription-based access control.

    Requests whose path matches an entry in ``protected_routes`` are checked
    against the tenant service before being forwarded:

    * no tenant ID can be resolved  -> 400 ``MISSING_TENANT_ID``
    * feature / level not granted   -> 403 ``SUBSCRIPTION_UPGRADE_REQUIRED``
    * otherwise the request continues down the middleware stack.

    Problems talking to the tenant service (non-200 response, timeout,
    transport error, unexpected exception) deliberately *fail open*: the
    request is allowed so that a tenant-service outage does not block the
    gateway.
    """

    # Paths that bypass subscription validation entirely. Matched with
    # ``Pattern.match`` (anchored at the start of the path). Compiled once at
    # class creation instead of being rebuilt on every request.
    _SKIP_PATTERNS = (
        re.compile(r'/health.*'),
        re.compile(r'/metrics.*'),
        re.compile(r'/api/v1/auth/.*'),
        re.compile(r'/api/v1/subscriptions/.*'),  # Subscription management itself
        re.compile(r'/api/v1/tenants/[^/]+/members.*'),  # Basic tenant info
        re.compile(r'/docs.*'),
        re.compile(r'/openapi\.json'),
    )

    # Captures the tenant ID segment of ``/api/v1/tenants/<id>/...`` paths.
    _TENANT_PATH_PATTERN = re.compile(r'/api/v1/tenants/([^/]+)/')

    # Incoming headers that must NOT be copied onto the body-less GET sent to
    # the tenant service. Forwarding the client's Content-Length /
    # Transfer-Encoding / Expect (e.g. from a POST) would make the upstream
    # wait for a request body that never arrives, stalling until the read
    # timeout (which then fails open). The remaining entries are hop-by-hop
    # headers that only apply to the original client connection.
    _EXCLUDED_FORWARD_HEADERS = (
        'host',
        'content-length',
        'transfer-encoding',
        'expect',
        'connection',
        'keep-alive',
        'upgrade',
        'te',
        'trailer',
    )

    def __init__(self, app, tenant_service_url: str):
        """Create the middleware.

        Args:
            app: The ASGI application being wrapped.
            tenant_service_url: Base URL of the tenant service used for
                subscription checks; any trailing slash is stripped.
        """
        super().__init__(app)
        self.tenant_service_url = tenant_service_url.rstrip('/')
        # Route pattern -> requirement. Patterns are tried in insertion order
        # with ``re.match`` and the first match wins. Every route is
        # currently configured at the 'basic' analytics level, i.e. open to
        # all tiers that have the analytics feature.
        self.protected_routes = {
            # Analytics routes
            r'/api/v1/tenants/[^/]+/analytics/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'  # Basic so all tiers can access analytics
            },
            r'/api/v1/tenants/[^/]+/forecasts/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'  # Basic so all tiers can access forecasting
            },
            r'/api/v1/tenants/[^/]+/predictions/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'  # Basic so all tiers can access predictions
            },
            # Training and AI models - available to all tiers
            r'/api/v1/tenants/[^/]+/training/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'  # Basic so all tiers can access training
            },
            r'/api/v1/tenants/[^/]+/models/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'  # Basic so all tiers can access models
            },
            # Production optimization - currently open to all tiers ('basic')
            r'/api/v1/tenants/[^/]+/production/optimization/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            # Statistics - currently open to all tiers ('basic')
            r'/api/v1/tenants/[^/]+/statistics.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            }
        }

    async def dispatch(self, request: Request, call_next):
        """Check subscription requirements, then forward or reject the request.

        Returns:
            The downstream response, or a 400/403 ``JSONResponse`` when the
            tenant cannot be determined or the subscription is insufficient.
        """
        # Skip subscription check for certain routes
        if self._should_skip_subscription_check(request):
            return await call_next(request)

        # Only routes listed in protected_routes need validation
        subscription_requirement = self._get_subscription_requirement(request.url.path)
        if not subscription_requirement:
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            return JSONResponse(
                status_code=400,
                content={
                    "error": "subscription_validation_failed",
                    "message": "Tenant ID required for subscription validation",
                    "code": "MISSING_TENANT_ID"
                }
            )

        validation_result = await self._validate_subscription(
            request,
            tenant_id,
            subscription_requirement['feature'],
            subscription_requirement['minimum_level']
        )
        if not validation_result['allowed']:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "subscription_required",
                    "message": validation_result['message'],
                    "code": "SUBSCRIPTION_UPGRADE_REQUIRED",
                    "details": {
                        "required_feature": subscription_requirement['feature'],
                        "required_level": subscription_requirement['minimum_level'],
                        "current_plan": validation_result.get('current_plan', 'unknown'),
                        "upgrade_url": "/app/settings/profile"
                    }
                }
            )

        # Subscription validation passed, continue with request
        return await call_next(request)

    def _should_skip_subscription_check(self, request: Request) -> bool:
        """Return True for requests that never need subscription validation.

        CORS preflight (``OPTIONS``) requests and any path matching one of
        ``_SKIP_PATTERNS`` (health, metrics, auth, subscription management,
        tenant members, docs, OpenAPI schema) are skipped.
        """
        if request.method == "OPTIONS":
            return True
        path = request.url.path
        return any(pattern.match(path) for pattern in self._SKIP_PATTERNS)

    def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
        """Return the requirement dict of the first protected route matching
        ``path`` (``{'feature': ..., 'minimum_level': ...}``), or None."""
        for pattern, requirement in self.protected_routes.items():
            if re.match(pattern, path):
                return requirement
        return None

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """Resolve the tenant ID for ``request``.

        Lookup order: the ``/api/v1/tenants/<id>/`` path segment, then the
        ``x-tenant-id`` header, then ``request.state.user['tenant_id']``
        (presumably populated by an upstream auth middleware).

        Returns:
            The tenant ID, or None if none of the sources provides one.
        """
        path_match = self._TENANT_PATH_PATTERN.search(request.url.path)
        if path_match:
            return path_match.group(1)

        tenant_id = request.headers.get('x-tenant-id')
        if tenant_id:
            return tenant_id

        if hasattr(request.state, 'user') and request.state.user:
            return request.state.user.get('tenant_id')

        return None

    async def _validate_subscription(
        self,
        request: Request,
        tenant_id: str,
        feature: str,
        minimum_level: str
    ) -> Dict[str, Any]:
        """Ask the tenant service whether ``tenant_id`` may use ``feature``.

        Calls ``GET {tenant_service_url}/api/v1/subscriptions/{tenant_id}/features/{feature}``
        with the caller's headers (minus body/hop-by-hop headers) plus the
        authenticated user context. For the ``analytics`` feature the returned
        ``feature_value`` is additionally compared against ``minimum_level``.

        Returns:
            A dict with ``allowed`` (bool), ``message`` (str) and
            ``current_plan`` (str). Any failure to obtain a valid answer
            (non-200, timeout, transport error, unexpected error) fails open
            with ``allowed=True``.
        """
        # Prefer the URL injected via the constructor; fall back to settings
        # if an empty value was passed.
        base_url = self.tenant_service_url or settings.TENANT_SERVICE_URL.rstrip('/')
        url = f"{base_url}/api/v1/subscriptions/{tenant_id}/features/{feature}"
        try:
            # Forward the caller's headers (same auth pattern as gateway
            # routes), minus headers that describe the original body or
            # connection; see _EXCLUDED_FORWARD_HEADERS.
            headers = dict(request.headers)
            for header_name in self._EXCLUDED_FORWARD_HEADERS:
                headers.pop(header_name, None)

            # Add user context headers if available (same as _proxy_request)
            if hasattr(request.state, 'user') and request.state.user:
                user = request.state.user
                headers["x-user-id"] = str(user.get('user_id', ''))
                headers["x-user-email"] = str(user.get('email', ''))
                headers["x-user-role"] = str(user.get('role', 'user'))
                headers["x-user-full-name"] = str(user.get('full_name', ''))
                headers["x-tenant-id"] = str(user.get('tenant_id', ''))

            # Short connect/write/pool timeouts so a slow tenant service does
            # not hold gateway requests for long.
            timeout_config = httpx.Timeout(
                connect=2.0,  # Connection timeout - short for gateway
                read=10.0,    # Read timeout
                write=2.0,    # Write timeout
                pool=2.0      # Pool timeout
            )
            async with httpx.AsyncClient(timeout=timeout_config) as client:
                feature_response = await client.get(url, headers=headers)

                if feature_response.status_code != 200:
                    logger.warning(
                        "Failed to check feature access",
                        tenant_id=tenant_id,
                        feature=feature,
                        status_code=feature_response.status_code,
                        response_text=feature_response.text,
                        url=url
                    )
                    # Fail open for availability (let service handle detailed check if needed)
                    return {
                        'allowed': True,
                        'message': 'Access granted (validation service unavailable)',
                        'current_plan': 'unknown'
                    }

                feature_data = feature_response.json()
                logger.info("Feature check response",
                            tenant_id=tenant_id,
                            feature=feature,
                            response=feature_data)

                if not feature_data.get('has_feature'):
                    return {
                        'allowed': False,
                        'message': f'Feature "{feature}" not available in your current plan',
                        'current_plan': feature_data.get('plan', 'unknown')
                    }

                # Analytics is tiered: also enforce the minimum level
                if feature == 'analytics':
                    feature_level = feature_data.get('feature_value', 'basic')
                    if not self._check_analytics_level(feature_level, minimum_level):
                        return {
                            'allowed': False,
                            'message': f'Analytics level "{minimum_level}" required. Current level: "{feature_level}"',
                            'current_plan': feature_data.get('plan', 'unknown')
                        }

                return {
                    'allowed': True,
                    'message': 'Access granted',
                    'current_plan': feature_data.get('plan', 'unknown')
                }

        # httpx signals timeouts with httpx.TimeoutException (a subclass of
        # httpx.RequestError), so it must be caught here, before the generic
        # RequestError handler below.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            logger.error(
                "Timeout validating subscription",
                tenant_id=tenant_id,
                feature=feature
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation timeout)',
                'current_plan': 'unknown'
            }
        except httpx.RequestError as e:
            logger.error(
                "Request error validating subscription",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability
            return {
                'allowed': True,
                'message': 'Access granted (validation service unavailable)',
                'current_plan': 'unknown'
            }
        except Exception as e:
            # Includes malformed JSON from the tenant service.
            logger.error(
                "Subscription validation error",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation error)',
                'current_plan': 'unknown'
            }

    def _check_analytics_level(self, current_level: str, required_level: str) -> bool:
        """Return True if ``current_level`` ranks at or above ``required_level``.

        Ranking: basic < advanced < predictive. Unknown levels rank 0, so an
        unknown current level only satisfies an unknown required level.
        """
        level_hierarchy = {
            'basic': 1,
            'advanced': 2,
            'predictive': 3
        }
        current_rank = level_hierarchy.get(current_level, 0)
        required_rank = level_hierarchy.get(required_level, 0)
        return current_rank >= required_rank