270 lines
9.1 KiB
Python
270 lines
9.1 KiB
Python
"""
|
|
API Rate Limiting Middleware for Gateway
|
|
Enforces subscription-based API call quotas per hour
|
|
"""
|
|
|
|
import structlog
|
|
import shared.redis_utils
|
|
from datetime import datetime, timezone
|
|
from fastapi import Request, HTTPException, status
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from typing import Optional
|
|
|
|
# Module-level structlog logger used by the middleware and its factory in this module.
logger = structlog.get_logger()
|
|
|
|
|
|
class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour
    - Demo: 1,000 calls/hour (same as Professional)

    Uses Redis to track API calls with hourly buckets keyed on the UTC hour.
    Failures inside the rate-limiting machinery (no Redis client, Redis
    errors, tenant-service lookup errors) fail open: the request is allowed.
    """

    # Hourly quota per (lower-cased) subscription tier. Missing or unknown
    # tiers get DEFAULT_QUOTA_LIMIT, which equals the starter quota.
    QUOTA_LIMITS = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }
    DEFAULT_QUOTA_LIMIT = 100

    # Tiers whose 429 response sets "upgrade_required": True.
    UPGRADEABLE_TIERS = frozenset({"starter", "professional"})

    # Path prefixes that are never rate limited.
    SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # TTL of an hourly counter key: 1 hour + 5 minutes buffer.
    BUCKET_TTL_SECONDS = 3900

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI app wrapped by this middleware.
            redis_client: Async Redis client (its ``incr``/``decr``/``expire``
                methods are awaited), or None to disable counting, in which
                case every request is allowed.
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check the tenant's hourly API quota, then forward the request.

        - Skipped paths and requests without a tenant ID pass straight through.
        - Over-quota requests get a 429 JSON response and are not forwarded.
        - Allowed requests are forwarded exactly once, and the response carries
          ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and
          ``X-RateLimit-Reset`` headers.
        - If the quota check itself fails, the request is forwarded (fail open).
        """
        # Skip rate limiting for certain paths
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        # Extract tenant_id from request
        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Guard only the rate-limit bookkeeping. call_next() must stay outside
        # this try. Otherwise an exception raised by the downstream app would
        # be treated as a limiter failure, and the request would be sent
        # downstream a second time.
        try:
            subscription_tier = await self._resolve_subscription_tier(tenant_id, request)
            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id,
                quota_limit
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path
            )
            # Return the 429 directly. Starlette only routes HTTPException to
            # the app's exception handlers when it is raised from routing or
            # endpoints; raised from BaseHTTPMiddleware it becomes a 500.
            return self._rate_limit_exceeded_response(
                subscription_tier, quota_limit, current_count
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(quota_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
        response.headers["X-RateLimit-Reset"] = self._get_reset_time()
        return response

    async def _resolve_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Return the normalized (stripped, lower-cased) subscription tier.

        Lookup order: the ``x-subscription-tier`` header (added by
        AuthMiddleware), then ``request.state.subscription_tier``, then the
        tenant service via ``_get_subscription_tier``.
        """
        tier = request.headers.get("x-subscription-tier")

        if not tier:
            # Fallback: get from request state if headers not available
            tier = getattr(request.state, "subscription_tier", None)

        if not tier:
            # Final fallback: get from tenant service (should rarely happen)
            tier = await self._get_subscription_tier(tenant_id, request)
            logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {tier}")

        # Normalize once so the quota lookup and the upgrade_required flag
        # agree regardless of the casing the tier arrived in.
        return str(tier).strip().lower()

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Return True if ``path`` starts with any prefix in SKIP_PATH_PREFIXES.
        """
        return path.startswith(self.SKIP_PATH_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract the tenant ID from the ``x-tenant-id`` header or the request path.

        A non-empty header wins. Otherwise the path segment following
        ``tenants`` (as in ``/api/v1/tenants/{tenant_id}/...``) is used.
        Returns None when neither source yields a non-empty ID.
        """
        tenant_id = request.headers.get("x-tenant-id")
        if tenant_id:
            return tenant_id

        path_parts = request.url.path.split("/")
        if "tenants" in path_parts:
            tenant_index = path_parts.index("tenants")
            if tenant_index + 1 < len(path_parts):
                # An empty segment (e.g. a trailing "/tenants/") means no ID.
                return path_parts[tenant_index + 1] or None

        return None

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Get the subscription tier from request state or the tenant service.

        No caching is performed: each call that misses request state makes
        an HTTP request (2 s timeout). Always returns a non-empty tier.
        "starter" is returned when the tier cannot be determined, i.e. on a
        non-200 response, a missing or empty "tier" field, or any error.
        """
        try:
            # Only trust request state when it actually holds a tier. A present
            # but None/empty attribute must not be returned to the caller.
            state_tier = getattr(request.state, "subscription_tier", None)
            if state_tier:
                return state_tier

            # Call tenant service to get tier
            import httpx
            from gateway.app.core.config import settings

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
                    headers={
                        "x-service": "gateway"
                    }
                )

            if response.status_code == 200:
                tier = response.json().get("tier")
                if tier:
                    return tier
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e)
            )

        return "starter"

    def _get_quota_limit(self, subscription_tier: Optional[str]) -> int:
        """
        Get the API calls per hour quota for a subscription tier.

        Matching is case-insensitive. A missing (None/empty) or unknown tier
        gets DEFAULT_QUOTA_LIMIT.
        """
        if not subscription_tier:
            return self.DEFAULT_QUOTA_LIMIT
        return self.QUOTA_LIMITS.get(subscription_tier.lower(), self.DEFAULT_QUOTA_LIMIT)

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int
    ) -> tuple[bool, int]:
        """
        Count this call against the tenant's current hourly bucket.

        The counter is incremented first and the result is then compared to
        the limit. Because Redis INCR is atomic, concurrent requests cannot
        all pass a read-then-write check and overshoot the quota. A rejected
        call is decremented again, so only allowed calls stay counted.

        Returns:
            (allowed, current_count): ``current_count`` is the bucket's count
            after this call. When there is no Redis client, or Redis raises,
            the result is ``(True, 0)`` (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # Create hourly bucket key (UTC hour)
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            new_count = await self.redis_client.incr(quota_key)
            # Re-arm the TTL on every call so the key can never be left
            # without an expiry, even if it was just created by this INCR.
            await self.redis_client.expire(quota_key, self.BUCKET_TTL_SECONDS)

            if new_count <= quota_limit:
                return True, new_count

            # Over quota: give back the slot this rejected call just took.
            current_count = await self.redis_client.decr(quota_key)
            return False, current_count

        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            # Fail open
            return True, 0

    def _rate_limit_exceeded_response(
        self,
        subscription_tier: str,
        quota_limit: int,
        current_count: int
    ):
        """
        Build the 429 response returned when a tenant exceeds its quota.

        The body is wrapped in ``{"detail": ...}``, which is the shape FastAPI's
        default HTTPException handler produces. That was presumably the
        intended shape of the original ``raise HTTPException(...)``.
        NOTE(review): confirm against any custom exception handler in the app.
        """
        from starlette.responses import JSONResponse

        reset_time = self._get_reset_time()
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": {
                    "error": "rate_limit_exceeded",
                    "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                    "current_count": current_count,
                    "quota_limit": quota_limit,
                    "reset_time": reset_time,
                    "upgrade_required": subscription_tier in self.UPGRADEABLE_TIERS,
                }
            },
            headers={
                "X-RateLimit-Limit": str(quota_limit),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": reset_time,
            },
        )

    def _get_reset_time(self) -> str:
        """
        Return the ISO-8601 UTC timestamp at which the current hourly bucket
        resets (the top of the next hour).
        """
        from datetime import timedelta

        now = datetime.now(timezone.utc)
        next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
        return next_hour.isoformat()
|
|
|
|
|
|
async def get_rate_limit_middleware(app):
    """
    Build an APIRateLimitMiddleware around ``app``, wired to Redis if possible.

    Loads ``settings.REDIS_URL`` and initializes a Redis client from it. If any
    step raises (including logging or constructing the middleware), a warning
    is logged and a middleware without a Redis client is returned instead.
    Without a client, the middleware allows every request.

    NOTE(review): Starlette's ``add_middleware`` normally takes a class, not an
    instance; confirm how callers install the object this coroutine returns.
    """
    try:
        from gateway.app.core.config import settings as gateway_settings

        connection = await shared.redis_utils.initialize_redis(gateway_settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        middleware = APIRateLimitMiddleware(app, redis_client=connection)
    except Exception as exc:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(exc)
        )
        middleware = APIRateLimitMiddleware(app, redis_client=None)

    return middleware
|