Improve onboarding
This commit is contained in:
@@ -21,6 +21,7 @@ from app.middleware.request_id import RequestIDMiddleware
|
||||
from app.middleware.auth import AuthMiddleware
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.middleware.rate_limiting import APIRateLimitMiddleware
|
||||
from app.middleware.subscription import SubscriptionMiddleware
|
||||
from app.middleware.demo_middleware import DemoMiddleware
|
||||
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
|
||||
@@ -90,9 +91,10 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# Custom middleware - Add in REVERSE order (last added = first executed)
|
||||
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
|
||||
app.add_middleware(LoggingMiddleware) # Executes 7th (outermost)
|
||||
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 6th
|
||||
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
|
||||
app.add_middleware(LoggingMiddleware) # Executes 8th (outermost)
|
||||
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 7th - Simple rate limit
|
||||
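# Note: APIRateLimitMiddleware is added in startup_event(), once the Redis client exists.
# NOTE(review): by the "last added = first executed" rule above, a middleware added at startup
# would run FIRST (outermost), not 6th as the execution-order comment states. Recent Starlette
# releases also raise RuntimeError when add_middleware() is called after the app has started,
# which the startup try/except would log as a Redis connection failure - confirm before relying on it.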
|
||||
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 5th
|
||||
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 4th - Enforce read-only mode
|
||||
app.add_middleware(AuthMiddleware) # Executes 3rd - Checks for demo context
|
||||
@@ -123,8 +125,13 @@ async def startup_event():
|
||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||
redis_client = await get_redis_client()
|
||||
logger.info("Connected to Redis for SSE streaming")
|
||||
|
||||
# Add API rate limiting middleware with Redis client
|
||||
app.add_middleware(APIRateLimitMiddleware, redis_client=redis_client)
|
||||
logger.info("API rate limiting middleware enabled with subscription-based quotas")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
logger.warning("API rate limiting middleware will fail open (allow all requests)")
|
||||
|
||||
metrics_collector.register_counter(
|
||||
"gateway_auth_requests_total",
|
||||
|
||||
260
gateway/app/middleware/rate_limiting.py
Normal file
260
gateway/app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
API Rate Limiting Middleware for Gateway
|
||||
Enforces subscription-based API call quotas per hour
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import shared.redis_utils
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class APIRateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce API rate limits based on subscription tier.

    Quota limits per hour:
    - Starter: 100 calls/hour
    - Professional: 1,000 calls/hour
    - Enterprise: 10,000 calls/hour
    - Demo: 1,000 calls/hour (same as Professional)

    Uses Redis to track API calls with hourly buckets. Every failure of the
    rate-limiting machinery itself (no Redis client, Redis error, tenant
    service error) fails open: the request is allowed through.
    """

    # Path prefixes that are never rate limited.
    _SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/",
        "/api/v1/plans",  # Public pricing info
    )

    # API calls allowed per hour, keyed by lower-cased subscription tier.
    _QUOTA_BY_TIER = {
        "starter": 100,
        "professional": 1000,
        "enterprise": 10000,
        "demo": 1000,  # Same as professional
    }

    # Quota applied to unknown tiers.
    _DEFAULT_QUOTA = 100

    # Lifetime of an hourly bucket key: 1 hour + 5 minutes buffer.
    _QUOTA_KEY_TTL_SECONDS = 3900

    def __init__(self, app, redis_client=None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            redis_client: Async Redis client used for the counters. When it is
                None (or falsy) no counting happens and every request is allowed.
        """
        super().__init__(app)
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next):
        """
        Check API rate limit before processing request.

        Returns the downstream response with ``X-RateLimit-*`` headers added,
        or a 429 JSON response when the tenant's hourly quota is exhausted.
        """
        # Skip rate limiting for certain paths
        if self._should_skip_rate_limit(request.url.path):
            return await call_next(request)

        tenant_id = self._extract_tenant_id(request)
        if not tenant_id:
            # No tenant ID - skip rate limiting for auth/public endpoints
            return await call_next(request)

        # Only the rate-limit bookkeeping is guarded. call_next() is kept out
        # of this try block on purpose: if it were inside, an exception raised
        # by the downstream application would hit the fail-open handler and the
        # request would be dispatched a second time.
        try:
            subscription_tier = await self._get_subscription_tier(tenant_id, request)
            quota_limit = self._get_quota_limit(subscription_tier)
            allowed, current_count = await self._check_and_increment_quota(
                tenant_id,
                quota_limit
            )
        except Exception as e:
            logger.error(
                "Rate limiting check failed, allowing request",
                tenant_id=tenant_id,
                error=str(e),
                path=request.url.path
            )
            # Fail open - allow request if rate limiting fails
            return await call_next(request)

        reset_time = self._get_reset_time()

        if not allowed:
            logger.warning(
                "API rate limit exceeded",
                tenant_id=tenant_id,
                subscription_tier=subscription_tier,
                current_count=current_count,
                quota_limit=quota_limit,
                path=request.url.path
            )

            # Return the response directly instead of raising HTTPException:
            # an exception raised from a BaseHTTPMiddleware is not routed
            # through FastAPI's HTTPException handler, so the client would
            # receive a 500 rather than a 429. The body keeps the
            # {"detail": ...} shape that a raised HTTPException produces.
            from starlette.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": {
                        "error": "rate_limit_exceeded",
                        "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                        "current_count": current_count,
                        "quota_limit": quota_limit,
                        "reset_time": reset_time,
                        # Compared case-insensitively, like the quota lookup.
                        "upgrade_required": str(subscription_tier).lower() in ("starter", "professional")
                    }
                },
                headers={
                    "X-RateLimit-Limit": str(quota_limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": reset_time,
                }
            )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(quota_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
        response.headers["X-RateLimit-Reset"] = reset_time

        return response

    def _should_skip_rate_limit(self, path: str) -> bool:
        """
        Determine if path should skip rate limiting.

        A path is skipped when it starts with any prefix in
        ``_SKIP_PATH_PREFIXES``.
        """
        return path.startswith(self._SKIP_PATH_PREFIXES)

    def _extract_tenant_id(self, request: Request) -> Optional[str]:
        """
        Extract tenant ID from request headers or path.

        The ``x-tenant-id`` header wins; otherwise the path segment that
        follows the first ``tenants`` segment is used. Returns None when
        neither yields a non-empty value.
        """
        # NOTE(review): the header is taken as-is from the incoming request;
        # confirm an upstream component validates it, otherwise a caller can
        # pick which tenant's quota is charged.
        tenant_id = request.headers.get("x-tenant-id")
        if tenant_id:
            return tenant_id

        # Try to extract from path /api/v1/tenants/{tenant_id}/...
        # NOTE(review): a literal sub-route such as /api/v1/tenants/bulk-children
        # would be treated as tenant "bulk-children" here - confirm acceptable.
        path_parts = request.url.path.split("/")
        if "tenants" in path_parts:
            next_index = path_parts.index("tenants") + 1
            if next_index < len(path_parts):
                # A trailing slash yields an empty segment; report it as absent.
                return path_parts[next_index] or None

        return None

    async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
        """
        Get the tenant's subscription tier.

        Uses ``request.state.subscription_tier`` when an earlier middleware has
        set it; otherwise asks the tenant service (2 second timeout). Nothing
        is cached here, so the fallback costs one HTTP call per request.

        Returns:
            The tier name, or "starter" when it cannot be determined
            (error, non-200 response, or missing/empty tier).
        """
        try:
            # Try to get from request state (if subscription middleware already ran)
            state_tier = getattr(request.state, "subscription_tier", None)
            if state_tier:
                return str(state_tier)

            # Call tenant service to get tier
            import httpx
            # NOTE(review): the application entry point imports this package as
            # ``app.*``; confirm ``gateway.app.core.config`` is importable in
            # deployment - an ImportError here silently falls back to "starter".
            from gateway.app.core.config import settings

            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/tier",
                    headers={
                        "x-service": "gateway"
                    }
                )

                if response.status_code == 200:
                    # A null/empty tier falls through to the "starter" default.
                    # Returning None would make the later .lower() raise and the
                    # request would be allowed through with no limit at all.
                    tier = response.json().get("tier")
                    if tier:
                        return str(tier)
        except Exception as e:
            logger.warning(
                "Failed to get subscription tier, defaulting to starter",
                tenant_id=tenant_id,
                error=str(e)
            )

        return "starter"

    def _get_quota_limit(self, subscription_tier: str) -> int:
        """
        Get API calls per hour quota for subscription tier.

        The lookup is case-insensitive; unknown tiers get the starter quota.
        """
        return self._QUOTA_BY_TIER.get(str(subscription_tier).lower(), self._DEFAULT_QUOTA)

    async def _check_and_increment_quota(
        self,
        tenant_id: str,
        quota_limit: int
    ) -> tuple[bool, int]:
        """
        Count this call against the tenant's hourly bucket and check the quota.

        Returns:
            (allowed: bool, current_count: int) - ``current_count`` includes
            this call. ``(True, 0)`` is returned when Redis is unavailable or
            the Redis operation fails (fail open).
        """
        if not self.redis_client:
            # No Redis - fail open
            return True, 0

        try:
            # Create hourly bucket key
            current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
            quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"

            # Increment first and judge by the value INCR returns. INCR is
            # atomic, so concurrent requests each see a distinct count; a
            # separate GET followed by INCR would let several requests pass
            # the check at the same time and overshoot the quota.
            new_count = int(await self.redis_client.incr(quota_key))

            # Set expiry (1 hour + 5 minutes buffer)
            await self.redis_client.expire(quota_key, self._QUOTA_KEY_TTL_SECONDS)

            return new_count <= quota_limit, new_count

        except Exception as e:
            logger.error(
                "Redis quota check failed",
                tenant_id=tenant_id,
                error=str(e)
            )
            # Fail open
            return True, 0

    def _get_reset_time(self) -> str:
        """
        Get the reset time for the current hour bucket (top of next hour).

        Returns:
            ISO 8601 timestamp in UTC.
        """
        from datetime import timedelta

        now = datetime.now(timezone.utc)
        next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)

        return next_hour.isoformat()
|
||||
|
||||
|
||||
async def get_rate_limit_middleware(app):
    """
    Build an APIRateLimitMiddleware around ``app``, wired to Redis when possible.

    If anything in the Redis-backed path raises, a middleware without a Redis
    client is returned instead, which allows every request (fail open).
    """
    # NOTE(review): the value returned by initialize_redis() is used as the
    # client here, while the application startup code obtains its client via a
    # separate get_redis_client() call - confirm initialize_redis() returns one.
    try:
        from gateway.app.core.config import settings as gateway_settings

        client = await shared.redis_utils.initialize_redis(gateway_settings.REDIS_URL)
        logger.info("API rate limiting middleware initialized with Redis")
        middleware = APIRateLimitMiddleware(app, redis_client=client)
    except Exception as exc:
        logger.warning(
            "Failed to initialize Redis for rate limiting, middleware will fail open",
            error=str(exc),
        )
        middleware = APIRateLimitMiddleware(app, redis_client=None)

    return middleware
|
||||
@@ -48,6 +48,12 @@ async def get_tenant_children(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get tenant children"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/children")
|
||||
|
||||
|
||||
# Registered ahead of the "/{tenant_id}/children/{path:path}" route that follows it.
@router.api_route("/bulk-children", methods=["POST", "OPTIONS"])
async def proxy_bulk_children(request: Request):
    """Proxy bulk children creation requests to tenant service"""
    # Fixed upstream path: this endpoint carries no tenant_id path parameter.
    upstream_path = "/api/v1/tenants/bulk-children"
    return await _proxy_to_tenant_service(request, upstream_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/children/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_children(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant children requests to tenant service"""
|
||||
|
||||
Reference in New Issue
Block a user