Files
bakery-ia/gateway/app/middleware/rate_limiting.py

270 lines
9.1 KiB
Python
Raw Normal View History

2025-12-18 13:26:32 +01:00
"""
API Rate Limiting Middleware for Gateway
Enforces subscription-based API call quotas per hour
"""
from datetime import datetime, timezone
from typing import Optional

import structlog
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

import shared.redis_utils
# Module-level structlog logger used by the rate limiting middleware and its factory below.
logger = structlog.get_logger()
class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour
    - Demo: 1,000 calls/hour (same as Professional)
    - Missing/unknown tier: 100 calls/hour

    Uses Redis to track API calls with fixed hourly (UTC) buckets. If no Redis
    client is configured, or any part of the quota check fails, the middleware
    fails open and forwards the request without enforcement.
    """

    # Path prefixes that are never rate limited.
    SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # API calls allowed per hour, keyed by lower-cased subscription tier.
    QUOTA_LIMITS = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }

    # Quota applied when the tier is missing or not present in QUOTA_LIMITS.
    DEFAULT_QUOTA_LIMIT = 100

    # Tiers whose 429 response reports that upgrading would raise the quota.
    UPGRADEABLE_TIERS = frozenset({"starter", "professional"})

    # Lifetime of an hourly counter key: 1 hour plus a 5 minute buffer.
    QUOTA_KEY_TTL_SECONDS = 3900

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            redis_client: Async Redis client (awaitable ``incr``/``expire``)
                holding the hourly counters. When ``None``, quotas are not
                enforced (fail open).
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check the tenant's hourly API quota, then forward the request.

        Returns:
            A 429 JSON response when the quota is exhausted. Otherwise, the
            downstream response with X-RateLimit-* headers attached.
        """
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Only the quota check is guarded. call_next() must stay outside this
        # try: a downstream failure should propagate, not be mistaken for a
        # rate-limit failure that replays the request a second time.
        try:
            subscription_tier = await self._resolve_subscription_tier(tenant_id, request)
            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id,
                quota_limit,
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path,
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path,
            )
            return self._rate_limit_exceeded_response(
                subscription_tier, quota_limit, current_count
            )

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(quota_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
        response.headers["X-RateLimit-Reset"] = self._get_reset_time()
        return response

    async def _resolve_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Determine the tenant's subscription tier for this request.

        Lookup order: the ``x-subscription-tier`` header, then
        ``request.state.subscription_tier``, then the tenant service.
        Empty values at each step fall through to the next one.
        """
        # NOTE(review): this header is assumed to be set by AuthMiddleware and
        # to overwrite any client-supplied value. If clients can send it
        # directly, they can choose their own quota - confirm.
        subscription_tier = request.headers.get("x-subscription-tier")
        if subscription_tier:
            return subscription_tier

        # Fallback: get from request state if headers not available
        subscription_tier = getattr(request.state, "subscription_tier", None)
        if subscription_tier:
            return subscription_tier

        # Final fallback: get from tenant service (should rarely happen)
        subscription_tier = await self._get_subscription_tier(tenant_id, request)
        logger.warning(
            "Subscription tier not found in headers or state, fetched from tenant service",
            tenant_id=tenant_id,
            subscription_tier=subscription_tier,
        )
        return subscription_tier

    def _rate_limit_exceeded_response(
        self,
        subscription_tier: str,
        quota_limit: int,
        current_count: int,
    ) -> JSONResponse:
        """
        Build the 429 response returned when a tenant has used up its quota.

        The body mirrors FastAPI's HTTPException serialisation
        (``{"detail": {...}}``). A response is returned instead of raising
        HTTPException because exceptions raised from
        ``BaseHTTPMiddleware.dispatch`` are outside FastAPI's exception
        handlers and would reach the client as a 500.
        """
        reset_at = self._next_reset()
        # Seconds until the hourly bucket rolls over (at least 1).
        retry_after = max(1, int((reset_at - datetime.now(timezone.utc)).total_seconds()))
        reset_time = reset_at.isoformat()
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "detail": {
                    "error": "rate_limit_exceeded",
                    "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                    "current_count": current_count,
                    "quota_limit": quota_limit,
                    "reset_time": reset_time,
                    # Case-insensitive, matching the quota lookup in _get_quota_limit.
                    "upgrade_required": (subscription_tier or "").lower() in self.UPGRADEABLE_TIERS,
                }
            },
            headers={
                "X-RateLimit-Limit": str(quota_limit),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": reset_time,
                "Retry-After": str(retry_after),
            },
        )

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Return True if ``path`` starts with any prefix in SKIP_PATH_PREFIXES.
        """
        return path.startswith(self.SKIP_PATH_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract the tenant ID from the ``x-tenant-id`` header or the URL path.

        The path form is ``.../tenants/{tenant_id}/...``.

        Returns:
            The tenant ID, or None if none is present (including an empty
            path segment after ``tenants``).
        """
        # NOTE(review): the header value is trusted as-is. If clients can set
        # it directly, they can spread calls across arbitrary tenant IDs to
        # dodge the quota - confirm that AuthMiddleware overwrites it.
        tenant_id = request.headers.get("x-tenant-id")
        if tenant_id:
            return tenant_id

        path_parts = request.url.path.split("/")
        try:
            tenant_index = path_parts.index("tenants")
        except ValueError:
            return None
        if tenant_index + 1 < len(path_parts):
            return path_parts[tenant_index + 1] or None
        return None

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Look up the tenant's subscription tier, defaulting to ``"starter"``.

        Uses a non-empty ``request.state.subscription_tier`` if present.
        Otherwise it queries the tenant service with a 2 second timeout. No
        caching is performed, so each call may make an HTTP request.

        Returns:
            A non-empty tier string. ``"starter"`` is returned on any failure,
            non-200 response, or missing/empty ``tier`` field.
        """
        try:
            # Try to get from request state (if subscription middleware already
            # ran). A present-but-empty value falls through to the service
            # lookup rather than being returned.
            state_tier = getattr(request.state, "subscription_tier", None)
            if state_tier:
                return state_tier

            # Call tenant service to get tier
            import httpx
            from gateway.app.core.config import settings

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
                    headers={"x-service": "gateway"},
                )
                if response.status_code == 200:
                    # A missing or null "tier" falls back to starter below.
                    tier = response.json().get("tier")
                    if tier:
                        return tier
                else:
                    logger.warning(
                        "Tenant service returned non-200 for subscription tier, defaulting to starter",
                        tenant_id=tenant_id,
                        status_code=response.status_code,
                    )
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e),
            )
        return "starter"

    def _get_quota_limit(self, subscription_tier: Optional[str]) -> int:
        """
        Get the API calls-per-hour quota for a subscription tier.

        Matching is case-insensitive. A missing or unknown tier gets
        DEFAULT_QUOTA_LIMIT.
        """
        if not subscription_tier:
            return self.DEFAULT_QUOTA_LIMIT
        return self.QUOTA_LIMITS.get(subscription_tier.lower(), self.DEFAULT_QUOTA_LIMIT)

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int,
    ) -> tuple[bool, int]:
        """
        Record one API call for the tenant and check it against the quota.

        The counter is incremented with a single atomic INCR, and the decision
        is made on the value INCR returns. Concurrent requests therefore each
        see a distinct count and cannot all pass a read-then-increment gap.
        Rejected calls are counted too (fixed-window semantics). The bucket
        resets at the top of each UTC hour.

        Returns:
            (allowed, current_count). current_count is the number of calls
            recorded in this hour's bucket, including this one. Returns
            (True, 0) when Redis is not configured or errors (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # Create hourly bucket key
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            new_count = await self.redis_client.incr(quota_key)
            # (Re)apply the expiry on every call so the key never lingers
            # without a TTL, even if an earlier expire call failed.
            await self.redis_client.expire(quota_key, self.QUOTA_KEY_TTL_SECONDS)

            return new_count <= quota_limit, new_count
        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e),
            )
            # Fail open
            return True, 0

    @staticmethod
    def _next_reset() -> datetime:
        """Return the top of the next UTC hour, when the current bucket rolls over."""
        from datetime import timedelta

        now = datetime.now(timezone.utc)
        return (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)

    def _get_reset_time(self) -> str:
        """
        Get the reset time for the current hour bucket (top of next hour).

        Returns:
            ISO-8601 timestamp string in UTC.
        """
        return self._next_reset().isoformat()
async def get_rate_limit_middleware(app):
    """
    Construct an APIRateLimitMiddleware for ``app``, backed by Redis if possible.

    Redis is initialised through ``shared.redis_utils.initialize_redis`` with
    ``settings.REDIS_URL``. If loading settings, initialising Redis, or
    building the middleware raises, a warning is logged. The middleware is then
    built without a Redis client, which makes it fail open.
    """
    try:
        from gateway.app.core.config import settings

        client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        middleware = APIRateLimitMiddleware(app, redis_client=client)
    except Exception as exc:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(exc)
        )
        middleware = APIRateLimitMiddleware(app, redis_client=None)
    return middleware