"""
Subscription Middleware - Enforces subscription limits and feature access
"""

import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional
import asyncio

from app.core.config import settings

logger = structlog.get_logger()

# Paths that never go through subscription validation. Compiled once at import
# time instead of on every request; matched with ``re.match`` (anchored at the
# start of the path).
_SKIP_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'/health.*',
        r'/metrics.*',
        r'/api/v1/auth/.*',
        r'/api/v1/subscriptions/.*',  # Subscription management itself
        r'/api/v1/tenants/[^/]+/members.*',  # Basic tenant info
        r'/docs.*',
        r'/openapi\.json',
    )
)

# Captures the tenant ID path segment of tenant-scoped API paths.
_TENANT_PATH_RE = re.compile(r'/api/v1/tenants/([^/]+)/')

# Incoming headers that must NOT be copied onto the outgoing (body-less) GET to
# the tenant service. ``host`` belongs to the gateway; the rest describe the
# incoming request's body framing or connection (hop-by-hop). Forwarding e.g.
# ``content-length`` from a POST makes the GET declare a body it never sends,
# which h11/httpx reject as a protocol error -- and the fail-open handler would
# then silently grant access.
_NON_FORWARDED_HEADERS = frozenset({
    'host',
    'content-length',
    'transfer-encoding',
    'connection',
    'keep-alive',
    'expect',
    'upgrade',
})


class SubscriptionMiddleware(BaseHTTPMiddleware):
    """Middleware to enforce subscription-based access control.

    Requests whose path matches one of ``protected_routes`` are checked against
    the tenant service before being passed downstream. If the tenant service
    cannot be reached, times out, answers with a non-200 status, or the check
    raises, the request is allowed through (fail open) so the gateway stays
    available.
    """

    # Rank of each analytics tier; a higher rank satisfies any lower one.
    # Unknown levels rank 0.
    _ANALYTICS_LEVELS = {
        'basic': 1,
        'advanced': 2,
        'predictive': 3,
    }

    def __init__(self, app, tenant_service_url: str):
        super().__init__(app)
        # Base URL of the tenant service; trailing slash stripped so API paths
        # can be appended directly.
        self.tenant_service_url = tenant_service_url.rstrip('/')

        # Route pattern -> required feature and minimum analytics level.
        # Patterns are matched with ``re.match`` (anchored at the path start).
        # Patterns ending in ``/.*`` need a trailing slash after the resource
        # name, so e.g. ``/api/v1/tenants/<id>/forecasts`` is NOT matched.
        # NOTE(review): confirm that collection endpoints are meant to be
        # unchecked.
        # Every route currently requires only the 'basic' analytics level, so
        # any plan that has the 'analytics' feature is allowed.
        self.protected_routes = {
            # Analytics, forecasting and predictions
            r'/api/v1/tenants/[^/]+/analytics/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            r'/api/v1/tenants/[^/]+/forecasts/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            r'/api/v1/tenants/[^/]+/predictions/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            # Training and AI models
            r'/api/v1/tenants/[^/]+/training/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            r'/api/v1/tenants/[^/]+/models/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            # Production optimization
            r'/api/v1/tenants/[^/]+/production/optimization/.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            },
            # Statistics (also matches a bare ``/statistics`` suffix)
            r'/api/v1/tenants/[^/]+/statistics.*': {
                'feature': 'analytics',
                'minimum_level': 'basic'
            }
        }

    async def dispatch(self, request: Request, call_next):
        """Process the request and check subscription requirements.

        Returns a 400 JSON response when a protected route has no resolvable
        tenant ID, a 403 JSON response when the tenant's plan lacks the
        required feature/level, and otherwise the downstream response.
        """
        # Skip subscription check for certain routes
        if self._should_skip_subscription_check(request):
            return await call_next(request)

        # Check if route requires subscription validation
        subscription_requirement = self._get_subscription_requirement(request.url.path)
        if not subscription_requirement:
            return await call_next(request)

        # Get tenant ID from request
        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            return JSONResponse(
                status_code=400,
                content={
                    "error": "subscription_validation_failed",
                    "message": "Tenant ID required for subscription validation",
                    "code": "MISSING_TENANT_ID"
                }
            )

        # Validate subscription
        validation_result = await self._validate_subscription(
            request,
            tenant_id,
            subscription_requirement['feature'],
            subscription_requirement['minimum_level']
        )

        if not validation_result['allowed']:
            return JSONResponse(
                status_code=403,
                content={
                    "error": "subscription_required",
                    "message": validation_result['message'],
                    "code": "SUBSCRIPTION_UPGRADE_REQUIRED",
                    "details": {
                        "required_feature": subscription_requirement['feature'],
                        "required_level": subscription_requirement['minimum_level'],
                        "current_plan": validation_result.get('current_plan', 'unknown'),
                        "upgrade_url": "/app/settings/profile"
                    }
                }
            )

        # Subscription validation passed, continue with request
        return await call_next(request)

    def _should_skip_subscription_check(self, request: Request) -> bool:
        """Return True for CORS preflights and health/auth/docs/public routes."""
        # Skip OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return True

        path = request.url.path
        return any(pattern.match(path) for pattern in _SKIP_PATTERNS)

    def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
        """Return the requirement of the first protected pattern matching *path*, else None."""
        for pattern, requirement in self.protected_routes.items():
            if re.match(pattern, path):
                return requirement
        return None

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """Extract the tenant ID from the request.

        Lookup order: the ``/api/v1/tenants/<id>/`` path segment, then the
        ``x-tenant-id`` header, then ``request.state.user['tenant_id']``
        (presumably set by the auth middleware). Returns None if none apply.
        """
        # Try to get from URL path first
        path_match = _TENANT_PATH_RE.search(request.url.path)
        if path_match:
            return path_match.group(1)

        # Try to get from headers
        tenant_id = request.headers.get('x-tenant-id')
        if tenant_id:
            return tenant_id

        # Try to get from user state (set by auth middleware)
        if hasattr(request.state, 'user') and request.state.user:
            return request.state.user.get('tenant_id')

        return None

    def _build_forward_headers(self, request: Request) -> Dict[str, str]:
        """Build headers for the tenant-service call from the incoming request.

        Copies the incoming headers (so auth credentials reach the tenant
        service), drops body-framing and hop-by-hop headers that would corrupt
        a body-less GET, and adds user-context headers when
        ``request.state.user`` is set.
        """
        headers = dict(request.headers)
        for name in _NON_FORWARDED_HEADERS:
            headers.pop(name, None)

        # Add user context headers if available (same as _proxy_request)
        if hasattr(request.state, 'user') and request.state.user:
            user = request.state.user
            headers["x-user-id"] = str(user.get('user_id', ''))
            headers["x-user-email"] = str(user.get('email', ''))
            headers["x-user-role"] = str(user.get('role', 'user'))
            headers["x-user-full-name"] = str(user.get('full_name', ''))
            headers["x-tenant-id"] = str(user.get('tenant_id', ''))

        return headers

    async def _validate_subscription(
        self,
        request: Request,
        tenant_id: str,
        feature: str,
        minimum_level: str
    ) -> Dict[str, Any]:
        """Ask the tenant service whether *tenant_id* may use *feature*.

        Returns a dict with ``allowed`` (bool), ``message`` (str) and
        ``current_plan`` (str). Never raises: transport failures, timeouts,
        non-200 answers and unexpected errors all fail open (``allowed=True``).
        """
        url = f"{self.tenant_service_url}/api/v1/subscriptions/{tenant_id}/features/{feature}"
        try:
            headers = self._build_forward_headers(request)

            # Gateway-appropriate timeouts: connect/write/pool are short, read
            # allows the tenant service some time to answer.
            timeout_config = httpx.Timeout(
                connect=2.0,
                read=10.0,
                write=2.0,
                pool=2.0
            )

            async with httpx.AsyncClient(timeout=timeout_config) as client:
                # Non-streaming GET: the body is fully read before returning,
                # so the response stays usable after the client is closed.
                feature_response = await client.get(url, headers=headers)

            if feature_response.status_code != 200:
                logger.warning(
                    "Failed to check feature access",
                    tenant_id=tenant_id,
                    feature=feature,
                    status_code=feature_response.status_code,
                    response_text=feature_response.text,
                    url=url
                )
                # Fail open for availability (let service handle detailed check if needed)
                return {
                    'allowed': True,
                    'message': 'Access granted (validation service unavailable)',
                    'current_plan': 'unknown'
                }

            feature_data = feature_response.json()
            logger.info("Feature check response", tenant_id=tenant_id, feature=feature, response=feature_data)

            if not feature_data.get('has_feature'):
                return {
                    'allowed': False,
                    'message': f'Feature "{feature}" not available in your current plan',
                    'current_plan': feature_data.get('plan', 'unknown')
                }

            # Check feature level if it's analytics
            if feature == 'analytics':
                feature_level = feature_data.get('feature_value', 'basic')
                if not self._check_analytics_level(feature_level, minimum_level):
                    return {
                        'allowed': False,
                        'message': f'Analytics level "{minimum_level}" required. Current level: "{feature_level}"',
                        'current_plan': feature_data.get('plan', 'unknown')
                    }

            return {
                'allowed': True,
                'message': 'Access granted',
                'current_plan': feature_data.get('plan', 'unknown')
            }

        # httpx signals timeouts with httpx.TimeoutException (a RequestError
        # subclass), so it must be caught before RequestError; asyncio's
        # TimeoutError is kept for any outer cancellation-based timeout.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            logger.error(
                "Timeout validating subscription",
                tenant_id=tenant_id,
                feature=feature
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation timeout)',
                'current_plan': 'unknown'
            }
        except httpx.RequestError as e:
            logger.error(
                "Request error validating subscription",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability
            return {
                'allowed': True,
                'message': 'Access granted (validation service unavailable)',
                'current_plan': 'unknown'
            }
        except Exception as e:
            logger.error(
                "Subscription validation error",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability (let service handle detailed check)
            return {
                'allowed': True,
                'message': 'Access granted (validation error)',
                'current_plan': 'unknown'
            }

    def _check_analytics_level(self, current_level: str, required_level: str) -> bool:
        """Return True if *current_level* ranks at or above *required_level*.

        Unknown levels rank 0, so an unknown required level is always met and
        an unknown current level only meets an unknown requirement.
        """
        current_rank = self._ANALYTICS_LEVELS.get(current_level, 0)
        required_rank = self._ANALYTICS_LEVELS.get(required_level, 0)
        return current_rank >= required_rank