85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
"""
|
|
Fixed-window rate-limiting middleware for the gateway, backed by Redis.
|
|
"""
|
|
|
|
import logging
|
|
from fastapi import Request, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
import redis.asyncio as redis
|
|
from datetime import datetime, timedelta
|
|
import hashlib
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module logger; rate-limit failures are reported here rather than raised.
logger = logging.getLogger(__name__)

# Async Redis client (redis.asyncio) used to store per-client request
# counters under keys of the form ``rate_limit:<client_id>``. Created once at
# import time from settings.REDIS_URL and shared by every request handled by
# this module.
redis_client = redis.from_url(settings.REDIS_URL)
|
|
|
|
async def rate_limit_middleware(request: Request, call_next):
    """Reject clients that exceed ``settings.RATE_LIMIT_REQUESTS`` per window.

    Each request is counted *before* it is handed to ``call_next``, using a
    single atomic Redis round-trip. Concurrent requests therefore cannot all
    slip past a stale counter, and requests whose handler raises are still
    counted. Windows are fixed: the counter key expires
    ``settings.RATE_LIMIT_WINDOW`` seconds after the first request seen in
    the window.

    If Redis is unavailable the request is allowed through (fail open), so a
    Redis outage does not take the gateway down.

    Args:
        request: Incoming request.
        call_next: Next handler in the middleware chain.

    Returns:
        The downstream response, or a 429 ``JSONResponse``. The 429 carries
        ``retry_after`` (seconds until the window resets) in the body and the
        same value in a ``Retry-After`` header.
    """
    # Skip rate limiting for health checks and metrics scraping.
    if request.url.path in ("/health", "/metrics"):
        return await call_next(request)

    # Get client identifier (user ID when authenticated, hashed IP otherwise).
    client_id = _get_client_id(request)

    hit = await _register_hit(client_id)
    if hit is not None:
        count, retry_after = hit
        # ``count`` includes this request, so the N-th request in a window has
        # count == N; allowing count <= limit admits exactly `limit` requests.
        if count > settings.RATE_LIMIT_REQUESTS:
            return JSONResponse(
                status_code=429,
                content={
                    "detail": "Rate limit exceeded",
                    "retry_after": retry_after,
                },
                # Standard HTTP hint for 429 responses (RFC 6585).
                headers={"Retry-After": str(retry_after)},
            )

    return await call_next(request)


async def _register_hit(client_id: str):
    """Atomically count one request for *client_id* in its current window.

    Returns:
        ``(count, retry_after)``. ``count`` is the number of requests in the
        current window including this one. ``retry_after`` is the number of
        seconds until the window's counter expires. Returns ``None`` if the
        Redis call failed; the caller then lets the request through.
    """
    key = f"rate_limit:{client_id}"
    window = settings.RATE_LIMIT_WINDOW
    try:
        # INCR and TTL in one MULTI/EXEC so both values describe the same state.
        async with redis_client.pipeline(transaction=True) as pipe:
            count, ttl = await pipe.incr(key).ttl(key).execute()
        if ttl < 0:
            # First hit of a new window, or a key left without an expiry
            # (e.g. a crash between INCR and EXPIRE): start the window now.
            await redis_client.expire(key, window)
            ttl = window
    except Exception:
        logger.exception("Rate limit check failed")
        return None
    return int(count), int(ttl)
|
|
|
|
def _get_client_id(request: Request) -> str:
    """Return the rate-limit bucket identifier for *request*.

    Authenticated requests with a ``user_id`` are bucketed per user as
    ``user:<user_id>``. All other requests are bucketed per client IP as
    ``ip:<md5 hex of the IP>``.

    ``request.client`` can be ``None`` (for example under some ASGI servers
    or test transports). In that case the IP is treated as ``"unknown"``
    instead of raising.
    """
    # Use user ID if authenticated, otherwise use IP.
    user = getattr(request.state, "user", None)
    if user:
        user_id = user.get("user_id")
        if user_id is not None:
            return f"user:{user_id}"
        # No user_id: fall through to the IP bucket rather than lumping every
        # such user into a single shared "user:unknown" bucket.

    # Hash IP address for privacy.
    # NOTE(review): an unsalted hash of an IPv4 address is easily reversed by
    # enumeration; it avoids storing raw IPs but is not strong anonymisation.
    ip = request.client.host if request.client is not None else "unknown"
    return f"ip:{hashlib.md5(ip.encode()).hexdigest()}"
|
|
|
|
async def _is_rate_limited(client_id: str) -> bool:
    """Return True if *client_id* has used up its quota for the current window.

    Reads the counter stored at ``rate_limit:<client_id>`` and compares it
    with ``settings.RATE_LIMIT_REQUESTS``. A missing key means no requests
    have been counted yet.

    Fails open: any error is logged with its traceback, and the client is
    treated as not limited. This covers Redis failures and a non-numeric
    stored value.
    """
    try:
        key = f"rate_limit:{client_id}"
        current_count = await redis_client.get(key)
        if current_count is None:
            return False
        return int(current_count) >= settings.RATE_LIMIT_REQUESTS
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Rate limit check failed: %s", e)
        return False
|
|
|
|
async def _update_rate_limit(client_id: str):
    """Count one request for *client_id* in its fixed window.

    Increments ``rate_limit:<client_id>`` and makes sure the key carries an
    expiry of ``settings.RATE_LIMIT_WINDOW`` seconds.

    The original version set the expiry only when the count was exactly 1.
    If that EXPIRE call failed, or the process died between INCR and EXPIRE,
    the key never expired and the client stayed blocked forever. Here the
    expiry is (re)applied whenever the key has none, which also repairs such
    orphaned keys.

    Errors are logged (with traceback) and swallowed; counting is best-effort.
    """
    key = f"rate_limit:{client_id}"
    try:
        # INCR and TTL in one MULTI/EXEC so the TTL reflects the incremented key.
        async with redis_client.pipeline(transaction=True) as pipe:
            _count, ttl = await pipe.incr(key).ttl(key).execute()
        # TTL is -1 when the key exists without an expiry (e.g. it was just
        # created by INCR): start the window now.
        if ttl < 0:
            await redis_client.expire(key, settings.RATE_LIMIT_WINDOW)
    except Exception as e:
        logger.exception("Rate limit update failed: %s", e)