""" Rate limiting middleware for gateway """ import logging import time from fastapi import Request from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response from typing import Dict, Optional import asyncio logger = logging.getLogger(__name__) class RateLimitMiddleware(BaseHTTPMiddleware): """Rate limiting middleware class""" def __init__(self, app, calls_per_minute: int = 60): super().__init__(app) self.calls_per_minute = calls_per_minute self.requests: Dict[str, list] = {} self._cleanup_task = None async def dispatch(self, request: Request, call_next) -> Response: """Process request with rate limiting""" # Skip rate limiting for health checks if request.url.path in ["/health", "/metrics"]: return await call_next(request) # Get client identifier client_id = self._get_client_id(request) # Check rate limit if self._is_rate_limited(client_id): logger.warning(f"Rate limit exceeded for client: {client_id}") return JSONResponse( status_code=429, content={"detail": "Rate limit exceeded"} ) # Record request self._record_request(client_id) # Process request return await call_next(request) def _get_client_id(self, request: Request) -> str: """Get client identifier""" # Try to get user ID from state (if authenticated) if hasattr(request.state, 'user') and request.state.user: return f"user:{request.state.user.get('user_id', 'unknown')}" # Fall back to IP address return f"ip:{request.client.host if request.client else 'unknown'}" def _is_rate_limited(self, client_id: str) -> bool: """Check if client is rate limited""" now = time.time() minute_ago = now - 60 # Get recent requests for this client if client_id not in self.requests: return False # Filter requests from last minute recent_requests = [ req_time for req_time in self.requests[client_id] if req_time > minute_ago ] # Update the list self.requests[client_id] = recent_requests # Check if limit exceeded return len(recent_requests) >= self.calls_per_minute def _record_request(self, client_id: str): """Record a request for rate limiting""" now = time.time() if client_id not in self.requests: self.requests[client_id] = [] self.requests[client_id].append(now) # Keep only last minute of requests minute_ago = now - 60 self.requests[client_id] = [ req_time for req_time in self.requests[client_id] if req_time > minute_ago ]