""" Subscription Middleware - Enforces subscription limits and feature access Updated to support standardized URL structure with tier-based access control """ import re import json import structlog from fastapi import Request, Response, HTTPException from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware import httpx from typing import Dict, Any, Optional, List import asyncio from datetime import datetime, timezone from app.core.config import settings from app.core.header_manager import header_manager from app.utils.subscription_error_responses import create_upgrade_required_response logger = structlog.get_logger() class SubscriptionMiddleware(BaseHTTPMiddleware): """ Middleware to enforce subscription-based access control Supports standardized URL structure: - Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers - Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers - Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+ - Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based) """ def __init__(self, app, tenant_service_url: str, redis_client=None): super().__init__(app) self.tenant_service_url = tenant_service_url.rstrip('/') self.redis_client = redis_client # Optional Redis client for abuse detection # Define route patterns that require subscription validation # Using new standardized URL structure self.protected_routes = { # ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY ===== # Any service analytics endpoint r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': { 'feature': 'analytics', 'minimum_tier': 'professional', 'allowed_tiers': ['professional', 'enterprise'], 'description': 'Analytics features (Professional/Enterprise only)' }, # ===== TRAINING SERVICE - ALL TIERS ===== r'^/api/v1/tenants/[^/]+/training/.*': { 'feature': 'ml_training', 'minimum_tier': 'basic', 'allowed_tiers': ['basic', 'professional', 'enterprise'], 'description': 'Machine learning model training (Available for all tiers)' }, # ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE ===== # Advanced reporting and exports r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': { 'feature': 'advanced_exports', 'minimum_tier': 'professional', 'allowed_tiers': ['professional', 'enterprise'], 'description': 'Advanced export formats (Professional/Enterprise only)' }, # Bulk operations r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': { 'feature': 'bulk_operations', 'minimum_tier': 'professional', 'allowed_tiers': ['professional', 'enterprise'], 'description': 'Bulk operations (Professional/Enterprise only)' }, } # Routes that are explicitly allowed for all tiers (no check needed) self.public_tier_routes = [ # Base CRUD operations - ALL TIERS r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$', r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$', # Dashboard routes - ALL TIERS r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*', # Operations routes - ALL TIERS (role-based control applies) r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*', ] async def dispatch(self, request: Request, call_next): """Process the request and check subscription requirements""" # Skip subscription check for certain routes if self._should_skip_subscription_check(request): return await call_next(request) # Check if route is explicitly allowed for all tiers if self._is_public_tier_route(request.url.path): return await call_next(request) # Check if route requires subscription validation subscription_requirement = self._get_subscription_requirement(request.url.path) if not subscription_requirement: return await call_next(request) # Get tenant ID from request tenant_id = self._extract_tenant_id(request) if not tenant_id: return JSONResponse( status_code=400, content={ "error": "subscription_validation_failed", "message": "Tenant ID required for subscription validation", "code": "MISSING_TENANT_ID" } ) # Validate subscription with new tier-based system validation_result = await self._validate_subscription_tier( request, tenant_id, subscription_requirement.get('feature'), subscription_requirement.get('minimum_tier'), subscription_requirement.get('allowed_tiers', []) ) if not validation_result['allowed']: # Use enhanced error response with conversion optimization feature = subscription_requirement.get('feature') current_tier = validation_result.get('current_tier', 'unknown') required_tier = subscription_requirement.get('minimum_tier') allowed_tiers = subscription_requirement.get('allowed_tiers', []) # Create conversion-optimized error response enhanced_response = create_upgrade_required_response( feature=feature, current_tier=current_tier, required_tier=required_tier, allowed_tiers=allowed_tiers, custom_message=validation_result.get('message') ) return JSONResponse( status_code=enhanced_response.status_code, content=enhanced_response.dict() ) # Subscription validation passed, continue with request response = await call_next(request) return response def _is_public_tier_route(self, path: str) -> bool: """ Check if route is explicitly allowed for all subscription tiers Args: path: Request path Returns: True if route is allowed for all tiers """ for pattern in self.public_tier_routes: if re.match(pattern, path): logger.debug("Route allowed for all tiers", path=path, pattern=pattern) return True return False def _should_skip_subscription_check(self, request: Request) -> bool: """Check if subscription validation should be skipped""" path = request.url.path method = request.method # Skip for health checks, auth, and public routes skip_patterns = [ r'/health.*', r'/metrics.*', r'/api/v1/auth/.*', r'/api/v1/subscriptions/.*', # Subscription management itself r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context r'/docs.*', r'/openapi\.json', # Training monitoring endpoints (WebSocket and status checks) r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint ] # Skip OPTIONS requests (CORS preflight) if method == "OPTIONS": return True for pattern in skip_patterns: if re.match(pattern, path): return True return False def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]: """Get subscription requirement for a given path""" for pattern, requirement in self.protected_routes.items(): if re.match(pattern, path): return requirement return None def _extract_tenant_id(self, request: Request) -> Optional[str]: """Extract tenant ID from request""" # Try to get from URL path first path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path) if path_match: return path_match.group(1) # Try to get from headers tenant_id = request.headers.get('x-tenant-id') if tenant_id: return tenant_id # Try to get from user state (set by auth middleware) if hasattr(request.state, 'user') and request.state.user: return request.state.user.get('tenant_id') return None async def _validate_subscription_tier( self, request: Request, tenant_id: str, feature: Optional[str], minimum_tier: str, allowed_tiers: List[str] ) -> Dict[str, Any]: """ Validate subscription tier access using cached subscription lookup Args: request: FastAPI request tenant_id: Tenant ID feature: Feature name (optional, for additional checks) minimum_tier: Minimum required subscription tier allowed_tiers: List of allowed subscription tiers Returns: Dict with 'allowed' boolean and additional metadata """ try: # Check if JWT already has subscription if hasattr(request.state, 'user') and request.state.user: user_context = request.state.user user_id = user_context.get('user_id', 'unknown') if user_context.get("subscription_from_jwt"): # Use JWT data directly - NO HTTP CALL! current_tier = user_context.get("subscription_tier", "starter") logger.debug("Using subscription tier from JWT (no HTTP call)", tenant_id=tenant_id, current_tier=current_tier, minimum_tier=minimum_tier, allowed_tiers=allowed_tiers) if current_tier not in [tier.lower() for tier in allowed_tiers]: tier_names = ', '.join(allowed_tiers) return { 'allowed': False, 'message': f'This feature requires a {tier_names} subscription plan', 'current_tier': current_tier } await self._log_subscription_access( tenant_id, user_id, feature, current_tier, True, "jwt" ) return { 'allowed': True, 'message': 'Access granted (JWT subscription)', 'current_tier': current_tier } # Use unified HeaderManager for consistent header handling headers = header_manager.get_all_headers_for_proxy(request) # Extract user_id for logging (fallback path) user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown') # Call tenant service fast tier endpoint with caching (fallback for old tokens) timeout_config = httpx.Timeout( connect=1.0, # Connection timeout - very short for cached endpoint read=5.0, # Read timeout - short for cached lookup write=1.0, # Write timeout pool=1.0 # Pool timeout ) async with httpx.AsyncClient(timeout=timeout_config) as client: # Use fast cached tier endpoint tier_response = await client.get( f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/tier", headers=headers ) if tier_response.status_code != 200: logger.warning( "Failed to get subscription tier from cache", tenant_id=tenant_id, status_code=tier_response.status_code, response_text=tier_response.text ) # Fail open for availability return { 'allowed': True, 'message': 'Access granted (validation service unavailable)', 'current_tier': 'unknown' } tier_data = tier_response.json() current_tier = tier_data.get('tier', 'starter').lower() logger.debug("Subscription tier check (cached)", tenant_id=tenant_id, current_tier=current_tier, minimum_tier=minimum_tier, allowed_tiers=allowed_tiers, cached=tier_data.get('cached', False)) # Check if current tier is in allowed tiers if current_tier not in [tier.lower() for tier in allowed_tiers]: tier_names = ', '.join(allowed_tiers) await self._log_subscription_access( tenant_id, user_id, feature, current_tier, False, "jwt" ) return { 'allowed': False, 'message': f'This feature requires a {tier_names} subscription plan', 'current_tier': current_tier } # Tier check passed await self._log_subscription_access( tenant_id, user_id, feature, current_tier, True, "database" ) return { 'allowed': True, 'message': 'Access granted', 'current_tier': current_tier } except asyncio.TimeoutError: logger.error( "Timeout validating subscription", tenant_id=tenant_id, feature=feature ) # Fail open for availability (let service handle detailed check) return { 'allowed': True, 'message': 'Access granted (validation timeout)', 'current_plan': 'unknown' } except httpx.RequestError as e: logger.error( "Request error validating subscription", tenant_id=tenant_id, feature=feature, error=str(e) ) # Fail open for availability return { 'allowed': True, 'message': 'Access granted (validation service unavailable)', 'current_plan': 'unknown' } except Exception as e: logger.error( "Subscription validation error", tenant_id=tenant_id, feature=feature, error=str(e) ) # Fail open for availability (let service handle detailed check) return { 'allowed': True, 'message': 'Access granted (validation error)', 'current_plan': 'unknown' } async def _log_subscription_access( self, tenant_id: str, user_id: str, requested_feature: str, current_tier: str, access_granted: bool, source: str # "jwt" or "database" ): """ Log all subscription-gated access attempts for audit and anomaly detection. """ logger.info( "Subscription access check", tenant_id=tenant_id, user_id=user_id, feature=requested_feature, tier=current_tier, granted=access_granted, source=source, timestamp=datetime.now(timezone.utc).isoformat() ) # For denied access, check for suspicious patterns if not access_granted: await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature) async def _check_for_abuse_patterns( self, tenant_id: str, user_id: str, feature: str ): """ Detect potential abuse patterns like repeated premium feature access attempts. """ if not self.redis_client: return # Track denied attempts in a sliding window key = f"subscription_denied:{tenant_id}:{user_id}:{feature}" try: attempts = await self.redis_client.incr(key) if attempts == 1: await self.redis_client.expire(key, 3600) # 1 hour window # Alert if too many denied attempts (potential bypass attempt) if attempts > 10: logger.warning( "SECURITY: Excessive premium feature access attempts detected", tenant_id=tenant_id, user_id=user_id, feature=feature, attempts=attempts, window="1 hour" ) # Could trigger alert to security team here except Exception as e: logger.warning("Failed to track abuse patterns", error=str(e))