Files
bakery-ia/gateway/app/middleware/rate_limit.py

94 lines
3.0 KiB
Python
Raw Normal View History

"""
Rate limiting middleware for gateway
"""
import logging
2025-07-17 19:54:04 +02:00
import time
from fastapi import Request
from fastapi.responses import JSONResponse
2025-07-17 19:54:04 +02:00
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Dict, Optional
import asyncio
# Module-scoped logger: records are emitted under this module's dotted import name.
logger = logging.getLogger(name=__name__)

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client sliding-window rate limiting middleware.

    A client is keyed by its authenticated user id when ``request.state.user`` is set,
    otherwise by its remote IP. Each client may make at most ``calls_per_minute``
    accepted requests within any ``window_seconds``-long window. Requests over the
    limit get HTTP 429 with a ``Retry-After`` header and are not counted.

    Counters live in this instance's memory only; they are not shared across
    processes or replicas.
    """

    # Paths that bypass rate limiting entirely.
    EXEMPT_PATHS = frozenset({"/health", "/metrics"})

    def __init__(
        self,
        app,
        calls_per_minute: int = 60,
        window_seconds: float = 60.0,
    ):
        """Create the middleware.

        Args:
            app: The ASGI application wrapped by this middleware.
            calls_per_minute: Maximum accepted requests per client per window
                (named for the default 60-second window). A value <= 0 rejects
                every non-exempt request.
            window_seconds: Length of the sliding window, in seconds.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        self.window_seconds = window_seconds
        # client_id -> non-decreasing monotonic timestamps of accepted requests
        # that may still fall inside the window.
        self.requests: Dict[str, list] = {}
        # Not used by this class; kept so code that reads the attribute still works.
        self._cleanup_task = None
        # Monotonic time of the last idle-client sweep (see _evict_idle_clients).
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next) -> Response:
        """Answer 429 if the request's client is over its limit; otherwise forward it."""
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        now = time.monotonic()
        self._evict_idle_clients(now)

        client_id = self._get_client_id(request)
        if self._is_rate_limited(client_id, now):
            logger.warning("Rate limit exceeded for client: %s", client_id)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(self._retry_after(client_id, now))},
            )

        # There is no await between the check above and this record, so on one
        # event loop no other request can interleave between the two.
        self._record_request(client_id, now)
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Return the rate-limit key: ``user:<user_id>`` if authenticated, else ``ip:<host>``."""
        # request.state.user is presumably a dict set by an upstream auth step — confirm.
        user = getattr(request.state, "user", None)
        if user:
            return f"user:{user.get('user_id', 'unknown')}"
        # NOTE(review): behind a reverse proxy request.client may be the proxy's
        # address, making all anonymous clients share one bucket — confirm deployment.
        host = request.client.host if request.client else "unknown"
        return f"ip:{host}"

    def _prune(self, client_id: str, now: float) -> list:
        """Drop ``client_id``'s timestamps that left the window; return the remaining list."""
        cutoff = now - self.window_seconds
        recent = [stamp for stamp in self.requests.get(client_id, ()) if stamp > cutoff]
        if recent:
            self.requests[client_id] = recent
        else:
            # Don't keep empty entries: they would accumulate, one per client ever seen.
            self.requests.pop(client_id, None)
        return recent

    def _is_rate_limited(self, client_id: str, now: Optional[float] = None) -> bool:
        """Return True if ``client_id`` already used its full quota in the current window.

        Expired timestamps for the client are pruned as a side effect.
        """
        if now is None:
            now = time.monotonic()
        return len(self._prune(client_id, now)) >= self.calls_per_minute

    def _record_request(self, client_id: str, now: Optional[float] = None) -> None:
        """Record an accepted request for ``client_id`` at ``now`` (default: current monotonic time)."""
        if now is None:
            now = time.monotonic()
        stamps = self._prune(client_id, now)
        stamps.append(now)
        self.requests[client_id] = stamps

    def _retry_after(self, client_id: str, now: Optional[float] = None) -> int:
        """Return whole seconds (>= 1) until ``client_id``'s oldest counted request expires."""
        import math  # local import: only needed on the 429 path

        if now is None:
            now = time.monotonic()
        stamps = self.requests.get(client_id)
        # Timestamps are appended in non-decreasing order, so stamps[0] is the oldest.
        remaining = stamps[0] + self.window_seconds - now if stamps else self.window_seconds
        return max(1, math.ceil(remaining))

    def _evict_idle_clients(self, now: float) -> None:
        """At most once per window, forget clients with no request inside the window.

        Without this sweep, ``self.requests`` would keep an entry for every client
        that never returns (e.g. every distinct IP), growing without bound.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        cutoff = now - self.window_seconds
        idle = [
            client_id
            for client_id, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for client_id in idle:
            del self.requests[client_id]