"""
Subscription Middleware - Enforces subscription limits and feature access.

Updated to support standardized URL structure with tier-based access control.
"""
import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional, List
import asyncio

from app.core.config import settings

logger = structlog.get_logger()


class SubscriptionMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce subscription-based access control.

    Supports standardized URL structure:
    - Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
    - Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
    - Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
    - Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)

    Tier validation fails open: if the tenant service returns a non-200
    response, times out, is unreachable, or validation raises, the request
    is allowed through and downstream services are left to enforce limits.
    """

    # Paths that bypass subscription validation entirely. Matched with
    # re.match, so each pattern is anchored at the start of the path.
    _SKIP_PATTERNS = (
        r'/health.*',
        r'/metrics.*',
        r'/api/v1/auth/.*',
        r'/api/v1/subscriptions/.*',  # Subscription management itself
        r'/api/v1/tenants/[^/]+/members.*',  # Basic tenant info
        r'/docs.*',
        r'/openapi\.json',
    )

    # Incoming-request headers that describe the *client's* connection/body
    # and must not be copied onto the body-less GET sent to the tenant
    # service. Forwarding e.g. a POST's Content-Length on a request with no
    # body makes the HTTP client raise a protocol error, which would
    # otherwise be swallowed as "service unavailable" and fail open.
    _NON_FORWARDED_HEADERS = ("host", "content-length", "transfer-encoding", "connection")

    # Timeouts for the tenant-service lookup; connect/write/pool are kept
    # short because this call sits on the gateway's request path.
    _TIMEOUT = httpx.Timeout(
        connect=2.0,  # Connection timeout - short for gateway
        read=10.0,    # Read timeout
        write=2.0,    # Write timeout
        pool=2.0      # Pool timeout
    )

    def __init__(self, app, tenant_service_url: str):
        """
        Args:
            app: The wrapped ASGI application.
            tenant_service_url: Base URL of the tenant service used for tier
                lookups (a trailing '/' is stripped).
        """
        super().__init__(app)
        self.tenant_service_url = tenant_service_url.rstrip('/')

        # Define route patterns that require subscription validation
        # Using new standardized URL structure
        self.protected_routes = {
            # ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
            # Any service analytics endpoint
            r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
                'feature': 'analytics',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Analytics features (Professional/Enterprise only)'
            },

            # ===== TRAINING SERVICE - ALL TIERS =====
            r'^/api/v1/tenants/[^/]+/training/.*': {
                'feature': 'ml_training',
                'minimum_tier': 'basic',
                'allowed_tiers': ['basic', 'professional', 'enterprise'],
                'description': 'Machine learning model training (Available for all tiers)'
            },

            # ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
            # Advanced reporting and exports
            r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
                'feature': 'advanced_exports',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Advanced export formats (Professional/Enterprise only)'
            },

            # Bulk operations
            r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
                'feature': 'bulk_operations',
                'minimum_tier': 'professional',
                'allowed_tiers': ['professional', 'enterprise'],
                'description': 'Bulk operations (Professional/Enterprise only)'
            },
        }

        # Routes that are explicitly allowed for all tiers (no check needed)
        self.public_tier_routes = [
            # Base CRUD operations - ALL TIERS
            # (negative lookahead keeps analytics/advanced-export/bulk paths out)
            r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
            r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',

            # Dashboard routes - ALL TIERS
            r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',

            # Operations routes - ALL TIERS (role-based control applies)
            r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
        ]

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and check subscription requirements.

        Checks run in this order: skip-list, all-tier allow-list,
        protected-route lookup, tenant-id extraction, then tier validation.
        Responds 400 when a protected route has no resolvable tenant id and
        402 when the tenant's tier is not allowed; otherwise the request is
        passed to the next handler.
        """
        path = request.url.path

        # Skip subscription check for certain routes
        if self._should_skip_subscription_check(request):
            return await call_next(request)

        # Check if route is explicitly allowed for all tiers
        if self._is_public_tier_route(path):
            return await call_next(request)

        # Check if route requires subscription validation
        subscription_requirement = self._get_subscription_requirement(path)
        if not subscription_requirement:
            return await call_next(request)

        # Get tenant ID from request
        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            return JSONResponse(
                status_code=400,
                content={
                    "error": "subscription_validation_failed",
                    "message": "Tenant ID required for subscription validation",
                    "code": "MISSING_TENANT_ID"
                }
            )

        # Validate subscription with new tier-based system
        validation_result = await self._validate_subscription_tier(
            request,
            tenant_id,
            subscription_requirement.get('feature'),
            subscription_requirement.get('minimum_tier'),
            subscription_requirement.get('allowed_tiers', [])
        )

        if not validation_result['allowed']:
            return JSONResponse(
                status_code=402,  # Payment Required for tier limitations
                content={
                    "error": "subscription_tier_insufficient",
                    "message": validation_result['message'],
                    "code": "SUBSCRIPTION_UPGRADE_REQUIRED",
                    "details": {
                        "required_feature": subscription_requirement.get('feature'),
                        "minimum_tier": subscription_requirement.get('minimum_tier'),
                        "allowed_tiers": subscription_requirement.get('allowed_tiers', []),
                        "current_tier": validation_result.get('current_tier', 'unknown'),
                        "description": subscription_requirement.get('description', ''),
                        "upgrade_url": "/app/settings/profile"
                    }
                }
            )

        # Subscription validation passed, continue with request
        return await call_next(request)

    def _is_public_tier_route(self, path: str) -> bool:
        """
        Check if route is explicitly allowed for all subscription tiers.

        Args:
            path: Request path

        Returns:
            True if route is allowed for all tiers
        """
        for pattern in self.public_tier_routes:
            if re.match(pattern, path):
                logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
                return True
        return False

    def _should_skip_subscription_check(self, request: Request) -> bool:
        """
        Check if subscription validation should be skipped.

        Skips OPTIONS requests (CORS preflight) and any path matching
        ``_SKIP_PATTERNS`` (health, metrics, auth, subscriptions, tenant
        members, docs, OpenAPI schema).
        """
        if request.method == "OPTIONS":
            return True

        path = request.url.path
        return any(re.match(pattern, path) for pattern in self._SKIP_PATTERNS)

    def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, Any]]:
        """
        Get subscription requirement for a given path.

        Returns the first matching entry of ``self.protected_routes`` (dict
        insertion order), or None when the path is not protected.
        """
        for pattern, requirement in self.protected_routes.items():
            if re.match(pattern, path):
                return requirement
        return None

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract tenant ID from request.

        Precedence: the ``/api/v1/tenants/{id}/`` path segment, then the
        ``x-tenant-id`` header, then ``request.state.user['tenant_id']``.
        """
        # Try to get from URL path first
        path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
        if path_match:
            return path_match.group(1)

        # Try to get from headers
        tenant_id = request.headers.get('x-tenant-id')
        if tenant_id:
            return tenant_id

        # Try to get from user state (set by auth middleware)
        if hasattr(request.state, 'user') and request.state.user:
            return request.state.user.get('tenant_id')

        return None

    def _build_forward_headers(self, request: Request) -> Dict[str, str]:
        """
        Build the headers sent to the tenant service.

        Copies the incoming request's headers (including any Authorization
        header), drops those in ``_NON_FORWARDED_HEADERS``, and adds
        user-context headers when ``request.state.user`` is set.
        """
        # Use the same authentication pattern as gateway routes
        headers = dict(request.headers)
        for name in self._NON_FORWARDED_HEADERS:
            headers.pop(name, None)

        # Add user context headers if available
        user = getattr(request.state, 'user', None)
        if user:
            headers["x-user-id"] = str(user.get('user_id', ''))
            headers["x-user-email"] = str(user.get('email', ''))
            headers["x-user-role"] = str(user.get('role', 'user'))
            headers["x-user-full-name"] = str(user.get('full_name', ''))
            headers["x-tenant-id"] = str(user.get('tenant_id', ''))

        return headers

    @staticmethod
    def _fail_open(message: str) -> Dict[str, Any]:
        """Build the 'allowed' result used when the tier cannot be determined."""
        return {
            'allowed': True,
            'message': message,
            'current_tier': 'unknown'
        }

    async def _validate_subscription_tier(
        self,
        request: Request,
        tenant_id: str,
        feature: Optional[str],
        minimum_tier: str,
        allowed_tiers: List[str]
    ) -> Dict[str, Any]:
        """
        Validate subscription tier access using tenant service.

        Args:
            request: FastAPI request
            tenant_id: Tenant ID
            feature: Feature name (optional, used for logging)
            minimum_tier: Minimum required subscription tier (logged only;
                the decision is made against ``allowed_tiers``)
            allowed_tiers: List of allowed subscription tiers (compared
                case-insensitively)

        Returns:
            Dict with 'allowed' (bool), 'message' (str) and 'current_tier'
            (str; 'unknown' when the tier could not be determined). Any
            failure to determine the tier fails open ('allowed': True).
        """
        try:
            headers = self._build_forward_headers(request)
            # Prefer the URL this middleware was configured with; fall back
            # to settings only if it was given an empty value.
            base_url = self.tenant_service_url or settings.TENANT_SERVICE_URL.rstrip('/')

            async with httpx.AsyncClient(timeout=self._TIMEOUT) as client:
                # Get tenant subscription information
                tenant_response = await client.get(
                    f"{base_url}/api/v1/tenants/{tenant_id}",
                    headers=headers
                )

            if tenant_response.status_code != 200:
                logger.warning(
                    "Failed to get tenant subscription",
                    tenant_id=tenant_id,
                    status_code=tenant_response.status_code,
                    response_text=tenant_response.text
                )
                # Fail open for availability
                return self._fail_open('Access granted (validation service unavailable)')

            tenant_data = tenant_response.json()
            # NOTE(review): 'starter' is not in any allowed_tiers list in this
            # middleware, so a tenant without a subscription_tier is denied on
            # every protected route. A null tier makes .lower() raise, which
            # fails open below. Confirm both behaviors are intended.
            current_tier = tenant_data.get('subscription_tier', 'starter').lower()

            logger.debug("Subscription tier check",
                         tenant_id=tenant_id,
                         current_tier=current_tier,
                         minimum_tier=minimum_tier,
                         allowed_tiers=allowed_tiers)

            # Check if current tier is in allowed tiers
            allowed = {tier.lower() for tier in allowed_tiers}
            if current_tier not in allowed:
                tier_names = ', '.join(allowed_tiers)
                return {
                    'allowed': False,
                    'message': f'This feature requires a {tier_names} subscription plan',
                    'current_tier': current_tier
                }

            # Tier check passed
            return {
                'allowed': True,
                'message': 'Access granted',
                'current_tier': current_tier
            }

        # httpx raises httpx.TimeoutException (a RequestError subclass), not
        # asyncio.TimeoutError, so it must be caught before RequestError.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            logger.error(
                "Timeout validating subscription",
                tenant_id=tenant_id,
                feature=feature
            )
            # Fail open for availability (let service handle detailed check)
            return self._fail_open('Access granted (validation timeout)')
        except httpx.RequestError as e:
            logger.error(
                "Request error validating subscription",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability
            return self._fail_open('Access granted (validation service unavailable)')
        except Exception as e:
            logger.error(
                "Subscription validation error",
                tenant_id=tenant_id,
                feature=feature,
                error=str(e)
            )
            # Fail open for availability (let service handle detailed check)
            return self._fail_open('Access granted (validation error)')