Fix gateway

This commit is contained in:
Urtzi Alfaro
2025-07-17 19:54:04 +02:00
parent caf7dea73a
commit 654d1c2fe8
4 changed files with 203 additions and 207 deletions

View File

@@ -3,83 +3,91 @@ Rate limiting middleware for gateway
"""
import logging
from fastapi import Request, HTTPException
import time
from fastapi import Request
from fastapi.responses import JSONResponse
import redis.asyncio as redis
from datetime import datetime, timedelta
import hashlib
from app.core.config import settings
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Dict, Optional
import asyncio
logger = logging.getLogger(__name__)
# Redis client for rate limiting
redis_client = redis.from_url(settings.REDIS_URL)
async def rate_limit_middleware(request: Request, call_next):
"""Rate limiting middleware"""
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware class"""
# Skip rate limiting for health checks
if request.url.path in ["/health", "/metrics"]:
def __init__(self, app, calls_per_minute: int = 60):
super().__init__(app)
self.calls_per_minute = calls_per_minute
self.requests: Dict[str, list] = {}
self._cleanup_task = None
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with rate limiting"""
# Skip rate limiting for health checks
if request.url.path in ["/health", "/metrics"]:
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check rate limit
if self._is_rate_limited(client_id):
logger.warning(f"Rate limit exceeded for client: {client_id}")
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"}
)
# Record request
self._record_request(client_id)
# Process request
return await call_next(request)
# Get client identifier (IP address or user ID)
client_id = _get_client_id(request)
# Check rate limit
if await _is_rate_limited(client_id):
return JSONResponse(
status_code=429,
content={
"detail": "Rate limit exceeded",
"retry_after": settings.RATE_LIMIT_WINDOW
}
)
# Process request
response = await call_next(request)
# Update rate limit counter
await _update_rate_limit(client_id)
return response
def _get_client_id(request: Request) -> str:
"""Get client identifier for rate limiting"""
# Use user ID if authenticated, otherwise use IP
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user.get('user_id', 'unknown')}"
else:
# Hash IP address for privacy
ip = request.client.host
return f"ip:{hashlib.md5(ip.encode()).hexdigest()}"
async def _is_rate_limited(client_id: str) -> bool:
"""Check if client is rate limited"""
try:
key = f"rate_limit:{client_id}"
current_count = await redis_client.get(key)
def _get_client_id(self, request: Request) -> str:
"""Get client identifier"""
# Try to get user ID from state (if authenticated)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user.get('user_id', 'unknown')}"
if current_count is None:
# Fall back to IP address
return f"ip:{request.client.host if request.client else 'unknown'}"
def _is_rate_limited(self, client_id: str) -> bool:
"""Check if client is rate limited"""
now = time.time()
minute_ago = now - 60
# Get recent requests for this client
if client_id not in self.requests:
return False
return int(current_count) >= settings.RATE_LIMIT_REQUESTS
# Filter requests from last minute
recent_requests = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]
except Exception as e:
logger.error(f"Rate limit check failed: {e}")
return False
async def _update_rate_limit(client_id: str):
"""Update rate limit counter"""
try:
key = f"rate_limit:{client_id}"
# Update the list
self.requests[client_id] = recent_requests
# Increment counter
current_count = await redis_client.incr(key)
# Check if limit exceeded
return len(recent_requests) >= self.calls_per_minute
def _record_request(self, client_id: str):
"""Record a request for rate limiting"""
now = time.time()
# Set TTL on first request
if current_count == 1:
await redis_client.expire(key, settings.RATE_LIMIT_WINDOW)
except Exception as e:
logger.error(f"Rate limit update failed: {e}")
if client_id not in self.requests:
self.requests[client_id] = []
self.requests[client_id].append(now)
# Keep only last minute of requests
minute_ago = now - 60
self.requests[client_id] = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]