2025-07-17 13:09:24 +02:00
|
|
|
"""
|
|
|
|
|
Rate limiting middleware for gateway
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
import logging
import math
import time
from collections import deque
from collections.abc import Mapping
from typing import Deque, Dict, Iterable, Optional

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
# Module-scoped logger, named after this module so log config can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
2025-07-17 19:54:04 +02:00
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client sliding-window rate limiting middleware.

    A client is identified by the user id on ``request.state.user`` when one
    is available, otherwise by the remote IP address. Each client may make at
    most ``calls_per_minute`` requests in any rolling 60-second window; excess
    requests are rejected with ``429 Too Many Requests`` and a ``Retry-After``
    header. Requests whose path is in ``exempt_paths`` are never limited.

    Counters are held in this instance's memory, so every worker process
    enforces its own limit independently.
    """

    # Length of the sliding window; matches the "minute" in calls_per_minute.
    _WINDOW_SECONDS = 60.0
    # Paths that bypass limiting when the caller gives no exempt_paths.
    _DEFAULT_EXEMPT_PATHS = ("/health", "/metrics")

    def __init__(
        self,
        app,
        calls_per_minute: int = 60,
        exempt_paths: Optional[Iterable[str]] = None,
    ):
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            calls_per_minute: Maximum accepted requests per client per window.
            exempt_paths: Exact URL paths that skip rate limiting. Defaults to
                ``/health`` and ``/metrics``.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        self.exempt_paths = frozenset(
            self._DEFAULT_EXEMPT_PATHS if exempt_paths is None else exempt_paths
        )
        # client_id -> monotonic timestamps of accepted requests, oldest first.
        self.requests: Dict[str, Deque[float]] = {}
        # Not used by this class; kept so code that reads the attribute still
        # works. Stale clients are evicted inline by _sweep_stale() instead.
        self._cleanup_task = None
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject the request with 429 if its client is over the limit.

        Exempt paths are forwarded untouched. Otherwise the request is
        counted against its client and passed to ``call_next``.
        """
        if request.url.path in self.exempt_paths:
            return await call_next(request)

        now = time.monotonic()
        self._sweep_stale(now)

        client_id = self._get_client_id(request)

        # The check and the record below run with no intervening await, so
        # concurrent requests on the same event loop cannot slip in between.
        if self._is_rate_limited(client_id, now):
            logger.warning("Rate limit exceeded for client: %s", client_id)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(self._retry_after(client_id, now))},
            )

        self._record_request(client_id, now)
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Return the rate-limit bucket key for ``request``.

        Returns ``"user:<id>"`` when ``request.state.user`` carries a
        ``user_id`` (as a mapping key or an attribute). Otherwise returns
        ``"ip:<host>"``, or ``"ip:unknown"`` when the peer address is absent.
        """
        user = getattr(request.state, "user", None)
        if user:
            if isinstance(user, Mapping):
                user_id = user.get("user_id")
            else:
                user_id = getattr(user, "user_id", None)
            # Without an id, fall back to the IP rather than pooling every
            # such user into one shared bucket.
            if user_id is not None:
                return f"user:{user_id}"

        # NOTE(review): behind a reverse proxy request.client may be the proxy
        # address unless the server honours forwarded headers; confirm setup.
        host = request.client.host if request.client else "unknown"
        return f"ip:{host}"

    def _prune(self, client_id: str, now: float) -> Optional[Deque[float]]:
        """Drop the client's expired timestamps and return what remains.

        Deletes the client's entry entirely and returns ``None`` when no
        timestamps are left inside the window (or the client is unknown).
        """
        stamps = self.requests.get(client_id)
        if stamps is None:
            return None
        cutoff = now - self._WINDOW_SECONDS
        # Stamps are appended in monotonic order, so expired ones sit at the left.
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()
        if not stamps:
            del self.requests[client_id]
            return None
        return stamps

    def _sweep_stale(self, now: float) -> None:
        """Evict clients with no requests in the window, at most once per window.

        This bounds ``self.requests`` to the clients active in roughly the last
        two windows, instead of every client ever seen.
        """
        if now - self._last_sweep < self._WINDOW_SECONDS:
            return
        self._last_sweep = now
        cutoff = now - self._WINDOW_SECONDS
        stale = [
            cid
            for cid, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for cid in stale:
            del self.requests[cid]

    def _is_rate_limited(self, client_id: str, now: Optional[float] = None) -> bool:
        """Return True if the client already has a full window of requests.

        Args:
            client_id: Bucket key from ``_get_client_id``.
            now: Monotonic time to evaluate at; defaults to the current time.
        """
        if now is None:
            now = time.monotonic()
        stamps = self._prune(client_id, now)
        return stamps is not None and len(stamps) >= self.calls_per_minute

    def _record_request(self, client_id: str, now: Optional[float] = None) -> None:
        """Record an accepted request for ``client_id`` at time ``now``.

        Args:
            client_id: Bucket key from ``_get_client_id``.
            now: Monotonic time of the request; defaults to the current time.
        """
        if now is None:
            now = time.monotonic()
        self.requests.setdefault(client_id, deque()).append(now)
        # Keep only the current window, even when called without a prior check.
        self._prune(client_id, now)

    def _retry_after(self, client_id: str, now: float) -> int:
        """Return whole seconds until the client may send another request (>= 1)."""
        stamps = self.requests.get(client_id)
        if not stamps:
            return int(self._WINDOW_SECONDS)
        # Rejected requests are never recorded, so a limited client holds at
        # most calls_per_minute stamps; a slot frees when the oldest expires.
        return max(1, math.ceil(stamps[0] + self._WINDOW_SECONDS - now))
|