"""
API Rate Limiting Middleware for Gateway
Enforces subscription-based API call quotas per hour
"""
import structlog
import shared.redis_utils
from datetime import datetime, timedelta, timezone
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from typing import Optional

logger = structlog.get_logger()


class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour
    - Demo: 1,000 calls/hour (same as Professional)
    Unknown tiers fall back to the Starter limit.

    Uses Redis to track API calls with hourly (UTC) buckets. When Redis is not
    configured or the quota check fails, requests are allowed (fail open).
    """

    # Path prefixes that are never rate limited.
    _SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # API calls allowed per hour, keyed by lower-cased subscription tier.
    _QUOTA_BY_TIER = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }
    _DEFAULT_QUOTA = 100

    # Tiers whose 429 response advertises that an upgrade is available.
    _UPGRADABLE_TIERS = ("starter", "professional")

    # TTL for an hourly bucket key: 1 hour + 5 minutes buffer.
    _BUCKET_TTL_SECONDS = 3900

    def __init__(self, app, redis_client=None):
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check the tenant's hourly API quota, then forward the request.

        Requests on skipped paths, or without a resolvable tenant ID, are
        forwarded untouched. When the quota is exhausted a 429 JSON response is
        returned directly. If the quota check itself fails, the request is
        allowed through without rate-limit headers (fail open).
        """
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Only the quota check belongs in this try block. Wrapping call_next()
        # here would re-run the downstream handler whenever it raised.
        try:
            subscription_tier = await self._resolve_subscription_tier(tenant_id, request)
            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id, quota_limit
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path,
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        reset_time = self._get_reset_time()

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path,
            )
            # Return the response directly: an HTTPException raised from a
            # BaseHTTPMiddleware bypasses the app's exception handlers and
            # would surface as a 500.
            return self._rate_limited_response(
                subscription_tier, current_count, quota_limit, reset_time
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(quota_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
        response.headers["X-RateLimit-Reset"] = reset_time
        return response

    def _rate_limited_response(
        self,
        subscription_tier: str,
        current_count: int,
        quota_limit: int,
        reset_time: str,
    ) -> JSONResponse:
        """
        Build the 429 response for an exhausted quota.

        The body uses the ``{"detail": {...}}`` envelope that FastAPI produces
        for an ``HTTPException``, so clients see the same payload shape.

        Args:
            subscription_tier: Normalised (lower-case) tier name.
            current_count: Calls counted in the current hourly bucket.
            quota_limit: Calls allowed per hour for the tier.
            reset_time: ISO-8601 UTC timestamp of the next bucket start.
        """
        seconds_to_reset = (
            datetime.fromisoformat(reset_time) - datetime.now(timezone.utc)
        ).total_seconds()
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": {
                    "error": "rate_limit_exceeded",
                    "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                    "current_count": current_count,
                    "quota_limit": quota_limit,
                    "reset_time": reset_time,
                    "upgrade_required": subscription_tier in self._UPGRADABLE_TIERS,
                }
            },
            headers={
                "X-RateLimit-Limit": str(quota_limit),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": reset_time,
                # Round up so clients never retry before the bucket rolls over.
                "Retry-After": str(max(1, int(seconds_to_reset) + 1)),
            },
        )

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Determine if path should skip rate limiting.

        Matching is by plain string prefix against ``_SKIP_PATH_PREFIXES``.
        """
        return path.startswith(self._SKIP_PATH_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract tenant ID from request headers or path.

        The ``x-tenant-id`` header wins; otherwise the path segment following
        ``tenants`` (``/api/v1/tenants/{tenant_id}/...``) is used. Returns
        None when neither yields a non-empty value.
        """
        tenant_id = request.headers.get("x-tenant-id")
        if tenant_id:
            return tenant_id

        path_parts = request.url.path.split("/")
        if "tenants" in path_parts:
            tenant_index = path_parts.index("tenants")
            if tenant_index + 1 < len(path_parts):
                # An empty segment (e.g. trailing "/tenants/") is not a tenant.
                return path_parts[tenant_index + 1] or None

        return None

    async def _resolve_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Determine the tenant's subscription tier, lower-cased.

        Lookup order:
        1. The ``x-subscription-tier`` header (added by AuthMiddleware).
        2. ``request.state.subscription_tier``.
        3. The tenant service, via ``_get_subscription_tier``.
        """
        subscription_tier = request.headers.get("x-subscription-tier")
        if not subscription_tier:
            subscription_tier = getattr(request.state, "subscription_tier", None)
        if not subscription_tier:
            # Should rarely happen: this costs an HTTP round-trip per request.
            subscription_tier = await self._get_subscription_tier(tenant_id, request)
            logger.warning(
                "Subscription tier not found in headers or state, fetched from tenant service",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
            )
        # Normalise once so the quota lookup and the upgrade hint agree.
        return (subscription_tier or "starter").lower()

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Get subscription tier from request state or the tenant service.

        No caching is performed: every call that reaches the tenant service
        makes one HTTP request (2 s timeout). Returns "starter" when the tier
        cannot be determined.
        """
        # A tier already stored on request state (and not None) is reused.
        state_tier = getattr(request.state, "subscription_tier", None)
        if state_tier:
            return state_tier

        try:
            import httpx
            from gateway.app.core.config import settings

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/tier",
                    headers={"x-service": "gateway"},
                )

            if response.status_code == 200:
                # Treat a missing or null "tier" as starter, not as None.
                return response.json().get("tier") or "starter"

            logger.warning(
                "Tenant service returned non-200 for subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                status_code=response.status_code,
            )
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e),
            )

        return "starter"

    def _get_quota_limit(self, subscription_tier: str) -> int:
        """
        Get API calls per hour quota for subscription tier.

        Lookup is case-insensitive; unknown or empty tiers get the starter quota.
        """
        return self._QUOTA_BY_TIER.get(
            (subscription_tier or "").lower(), self._DEFAULT_QUOTA
        )

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int,
    ) -> tuple[bool, int]:
        """
        Atomically count this call against the tenant's current hourly bucket.

        Returns:
            (allowed, current_count). ``current_count`` includes this call when
            allowed. When Redis is absent or errors, returns ``(True, 0)``
            (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # Hourly bucket key, e.g. quota:hourly:api_calls:<tenant>:2024-01-31-13
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            # INCR first: it is atomic, so concurrent requests cannot all
            # observe the same pre-increment count and overshoot the quota.
            new_count = await self.redis_client.incr(quota_key)
            if new_count == 1:
                # First call in this bucket created the key; give it a TTL.
                await self.redis_client.expire(quota_key, self._BUCKET_TTL_SECONDS)

            if new_count > quota_limit:
                # Roll back so rejected calls don't consume quota.
                await self.redis_client.decr(quota_key)
                return False, new_count - 1

            return True, new_count

        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e),
            )
            # Fail open
            return True, 0

    def _get_reset_time(self) -> str:
        """
        Get the reset time for the current hour bucket (top of next hour, UTC).

        Returns an ISO-8601 string with a ``+00:00`` offset.
        """
        now = datetime.now(timezone.utc)
        next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
        return next_hour.isoformat()


async def get_rate_limit_middleware(app):
    """
    Factory function to create rate limiting middleware with Redis client.

    If Redis initialisation fails, the middleware is still returned but with
    no Redis client, so every request is allowed (fail open).
    """
    try:
        from gateway.app.core.config import settings

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        return APIRateLimitMiddleware(app, redis_client=redis_client)
    except Exception as e:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(e),
        )
        return APIRateLimitMiddleware(app, redis_client=None)