Initial commit - production deployment
This commit is contained in:
93
gateway/app/middleware/rate_limit.py
Normal file
93
gateway/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Rate limiting middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware class"""
|
||||
|
||||
def __init__(self, app, calls_per_minute: int = 60):
|
||||
super().__init__(app)
|
||||
self.calls_per_minute = calls_per_minute
|
||||
self.requests: Dict[str, list] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with rate limiting"""
|
||||
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path in ["/health", "/metrics"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
|
||||
# Check rate limit
|
||||
if self._is_rate_limited(client_id):
|
||||
logger.warning(f"Rate limit exceeded for client: {client_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Record request
|
||||
self._record_request(client_id)
|
||||
|
||||
# Process request
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""Get client identifier"""
|
||||
# Try to get user ID from state (if authenticated)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
return f"ip:{request.client.host if request.client else 'unknown'}"
|
||||
|
||||
def _is_rate_limited(self, client_id: str) -> bool:
|
||||
"""Check if client is rate limited"""
|
||||
now = time.time()
|
||||
minute_ago = now - 60
|
||||
|
||||
# Get recent requests for this client
|
||||
if client_id not in self.requests:
|
||||
return False
|
||||
|
||||
# Filter requests from last minute
|
||||
recent_requests = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
# Update the list
|
||||
self.requests[client_id] = recent_requests
|
||||
|
||||
# Check if limit exceeded
|
||||
return len(recent_requests) >= self.calls_per_minute
|
||||
|
||||
def _record_request(self, client_id: str):
|
||||
"""Record a request for rate limiting"""
|
||||
now = time.time()
|
||||
|
||||
if client_id not in self.requests:
|
||||
self.requests[client_id] = []
|
||||
|
||||
self.requests[client_id].append(now)
|
||||
|
||||
# Keep only last minute of requests
|
||||
minute_ago = now - 60
|
||||
self.requests[client_id] = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
Reference in New Issue
Block a user