Improve onboarding
This commit is contained in:
260
gateway/app/middleware/rate_limiting.py
Normal file
260
gateway/app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
API Rate Limiting Middleware for Gateway
|
||||
Enforces subscription-based API call quotas per hour
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import shared.redis_utils
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour

    Uses Redis to track API calls with hourly buckets. The middleware fails
    open: if Redis is not configured or the quota check raises, the request
    is allowed through.
    """

    # Path prefixes that are never rate limited.
    SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # API calls per hour for each known subscription tier.
    QUOTA_BY_TIER = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }

    # Tier and quota used when the real tier is unknown or unrecognised.
    DEFAULT_TIER = "starter"
    DEFAULT_QUOTA = 100

    # Tiers for which a 429 response advertises that an upgrade would help.
    UPGRADABLE_TIERS = ("starter", "professional")

    # Lifetime of an hourly bucket key (1 hour + 5 minutes buffer).
    QUOTA_KEY_TTL_SECONDS = 3900

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI application being wrapped.
            redis_client: Async Redis client used for the counters. When it
                is None no counting happens and every request is allowed.
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check API rate limit before processing request.

        Returns the downstream response with ``X-RateLimit-*`` headers added,
        or a 429 JSON response when the tenant's hourly quota is exhausted.
        """
        # Skip rate limiting for certain paths
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Only the quota bookkeeping is guarded here. Keeping call_next()
        # outside the try block means an error raised by the downstream
        # application is not mistaken for a rate-limiting failure and the
        # request is never executed a second time.
        try:
            subscription_tier = await self._get_subscription_tier(tenant_id, request)
            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id,
                quota_limit
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        reset_time = self._get_reset_time()
        rate_limit_headers = {
            "X-RateLimit-Limit": str(quota_limit),
            "X-RateLimit-Remaining": str(max(0, quota_limit - current_count)),
            "X-RateLimit-Reset": reset_time,
        }

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path
            )

            # An HTTPException raised from middleware is not seen by the
            # application's exception handlers, so the 429 is returned as an
            # explicit response. The body keeps the {"detail": ...} envelope
            # that an HTTPException would normally produce.
            from starlette.responses import JSONResponse

            normalized_tier = str(subscription_tier).strip().lower()
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": {
                        "error": "rate_limit_exceeded",
                        "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                        "current_count": current_count,
                        "quota_limit": quota_limit,
                        "reset_time": reset_time,
                        "upgrade_required": normalized_tier in self.UPGRADABLE_TIERS
                    }
                },
                headers=rate_limit_headers
            )

        # Add rate limit headers to response
        response = await call_next(request)
        for header_name, header_value in rate_limit_headers.items():
            response.headers[header_name] = header_value

        return response

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Determine if path should skip rate limiting.

        A path is skipped when it starts with any of SKIP_PATH_PREFIXES.
        """
        return path.startswith(self.SKIP_PATH_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract tenant ID from request headers or path.

        The ``x-tenant-id`` header wins; otherwise the path segment following
        ``tenants`` (``/api/v1/tenants/{tenant_id}/...``) is used. Returns
        None when neither yields a non-empty value.
        """
        # NOTE(review): the header is client-supplied; confirm an upstream
        # layer validates it, otherwise a caller can pick its own bucket.
        header_value = request.headers.get("x-tenant-id")
        if header_value:
            return header_value

        path_parts = request.url.path.split("/")
        try:
            tenant_index = path_parts.index("tenants")
        except ValueError:
            return None

        candidate_index = tenant_index + 1
        if len(path_parts) > candidate_index and path_parts[candidate_index]:
            return path_parts[candidate_index]

        return None

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Get subscription tier for a tenant.

        Uses ``request.state.subscription_tier`` when it is already set to a
        non-empty value; otherwise asks the tenant service. Any failure, a
        non-200 answer, or a missing/empty tier falls back to the starter
        tier. Nothing is cached between requests.
        """
        try:
            # Try to get from request state (if subscription middleware already ran)
            state_tier = getattr(request.state, "subscription_tier", None)
            if state_tier:
                return state_tier

            # Call tenant service to get tier
            import httpx
            from urllib.parse import quote
            from gateway.app.core.config import settings

            # The tenant id may come straight from a request header, so it is
            # percent-encoded to keep it inside a single URL path segment.
            safe_tenant_id = quote(str(tenant_id), safe="")

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{safe_tenant_id}/tier",
                    headers={
                        "x-service": "gateway"
                    }
                )

            if response.status_code == 200:
                data = response.json()
                # `or` also covers an explicit null tier in the payload.
                return data.get("tier") or self.DEFAULT_TIER

            logger.warning(
                "Tenant service returned non-200 for subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                status_code=response.status_code
            )
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e)
            )

        return self.DEFAULT_TIER

    def _get_quota_limit(self, subscription_tier: str) -> int:
        """
        Get API calls per hour quota for subscription tier.

        The lookup is case-insensitive. Unknown or non-string tiers get the
        default (starter-level) quota.
        """
        if not isinstance(subscription_tier, str):
            return self.DEFAULT_QUOTA

        return self.QUOTA_BY_TIER.get(
            subscription_tier.strip().lower(),
            self.DEFAULT_QUOTA
        )

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int
    ) -> tuple[bool, int]:
        """
        Increment the tenant's hourly counter and check it against the quota.

        Returns:
            (allowed: bool, current_count: int)

            ``current_count`` is the counter value after this call. When
            Redis is not configured or the Redis call fails the result is
            ``(True, 0)`` (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # Create hourly bucket key
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            # Increment first and decide on the value INCR returns. A separate
            # read followed by an increment lets concurrent requests all pass
            # the check before any of them has been counted.
            new_count = int(await self.redis_client.incr(quota_key))

            # Set expiry (1 hour + 5 minutes buffer)
            await self.redis_client.expire(quota_key, self.QUOTA_KEY_TTL_SECONDS)

            return new_count <= quota_limit, new_count

        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            # Fail open
            return True, 0

    def _get_reset_time(self) -> str:
        """
        Get the reset time for the current hour bucket (top of next hour).

        Returns:
            ISO 8601 timestamp in UTC.
        """
        from datetime import timedelta

        now = datetime.now(timezone.utc)
        next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)

        return next_hour.isoformat()
|
||||
|
||||
|
||||
async def get_rate_limit_middleware(app):
    """
    Build an APIRateLimitMiddleware around ``app``, backed by Redis if possible.

    The Redis client is obtained from ``shared.redis_utils.initialize_redis``
    using ``settings.REDIS_URL``. If anything in that setup raises, a warning
    is logged and the middleware is built with ``redis_client=None`` instead.
    """
    try:
        from gateway.app.core.config import settings

        client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        middleware = APIRateLimitMiddleware(app, redis_client=client)
    except Exception as exc:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(exc)
        )
        middleware = APIRateLimitMiddleware(app, redis_client=None)

    return middleware
|
||||
Reference in New Issue
Block a user