Fix gateway
This commit is contained in:
@@ -3,8 +3,10 @@ Authentication middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
@@ -28,74 +30,77 @@ PUBLIC_ROUTES = [
|
||||
"/api/v1/auth/refresh"
|
||||
]
|
||||
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
"""Authentication middleware"""
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Authentication middleware class"""
|
||||
|
||||
# Check if route requires authentication
|
||||
if _is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Get token from header
|
||||
token = _extract_token(request)
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
# First try to verify token locally
|
||||
payload = jwt_handler.verify_token(token)
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with authentication"""
|
||||
|
||||
if payload:
|
||||
# Add user info to request state
|
||||
request.state.user = payload
|
||||
# Check if route requires authentication
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
else:
|
||||
# Token invalid or expired, verify with auth service
|
||||
user_info = await _verify_with_auth_service(token)
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
|
||||
# Get token from header
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
# First try to verify token locally
|
||||
payload = jwt_handler.verify_token(token)
|
||||
|
||||
if payload:
|
||||
# Add user info to request state
|
||||
request.state.user = payload
|
||||
return await call_next(request)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
# Token invalid or expired, verify with auth service
|
||||
user_info = await self._verify_with_auth_service(token)
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
return await call_next(request)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication failed"}
|
||||
)
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
"""Check if route is public"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from request"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
|
||||
"""Verify token with auth service"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication failed"}
|
||||
)
|
||||
|
||||
def _is_public_route(path: str) -> bool:
|
||||
"""Check if route is public"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from request"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
async def _verify_with_auth_service(token: str) -> Optional[dict]:
|
||||
"""Verify token with auth service"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service verification failed: {e}")
|
||||
return None
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service verification failed: {e}")
|
||||
return None
|
||||
|
||||
@@ -5,44 +5,53 @@ Logging middleware for gateway
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
import json
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
"""Logging middleware"""
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Logging middleware class"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"Request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"query_params": str(request.query_params),
|
||||
"client_host": request.client.host,
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"request_id": getattr(request.state, 'request_id', None)
|
||||
}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
f"Response: {response.status_code} in {duration:.3f}s",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"duration": duration,
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"request_id": getattr(request.state, 'request_id', None)
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with logging"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"Request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"query_params": str(request.query_params),
|
||||
"client_host": request.client.host if request.client else "unknown",
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
f"Response: {response.status_code} in {duration:.3f}s",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"duration": duration,
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -3,83 +3,91 @@ Rate limiting middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import Request, HTTPException
|
||||
import time
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
import redis.asyncio as redis
|
||||
from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
|
||||
from app.core.config import settings
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client for rate limiting
|
||||
redis_client = redis.from_url(settings.REDIS_URL)
|
||||
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
"""Rate limiting middleware"""
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware class"""
|
||||
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path in ["/health", "/metrics"]:
|
||||
def __init__(self, app, calls_per_minute: int = 60):
|
||||
super().__init__(app)
|
||||
self.calls_per_minute = calls_per_minute
|
||||
self.requests: Dict[str, list] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with rate limiting"""
|
||||
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path in ["/health", "/metrics"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
|
||||
# Check rate limit
|
||||
if self._is_rate_limited(client_id):
|
||||
logger.warning(f"Rate limit exceeded for client: {client_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Record request
|
||||
self._record_request(client_id)
|
||||
|
||||
# Process request
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier (IP address or user ID)
|
||||
client_id = _get_client_id(request)
|
||||
|
||||
# Check rate limit
|
||||
if await _is_rate_limited(client_id):
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"retry_after": settings.RATE_LIMIT_WINDOW
|
||||
}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Update rate limit counter
|
||||
await _update_rate_limit(client_id)
|
||||
|
||||
return response
|
||||
|
||||
def _get_client_id(request: Request) -> str:
|
||||
"""Get client identifier for rate limiting"""
|
||||
# Use user ID if authenticated, otherwise use IP
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||
else:
|
||||
# Hash IP address for privacy
|
||||
ip = request.client.host
|
||||
return f"ip:{hashlib.md5(ip.encode()).hexdigest()}"
|
||||
|
||||
async def _is_rate_limited(client_id: str) -> bool:
|
||||
"""Check if client is rate limited"""
|
||||
try:
|
||||
key = f"rate_limit:{client_id}"
|
||||
current_count = await redis_client.get(key)
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""Get client identifier"""
|
||||
# Try to get user ID from state (if authenticated)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||
|
||||
if current_count is None:
|
||||
# Fall back to IP address
|
||||
return f"ip:{request.client.host if request.client else 'unknown'}"
|
||||
|
||||
def _is_rate_limited(self, client_id: str) -> bool:
|
||||
"""Check if client is rate limited"""
|
||||
now = time.time()
|
||||
minute_ago = now - 60
|
||||
|
||||
# Get recent requests for this client
|
||||
if client_id not in self.requests:
|
||||
return False
|
||||
|
||||
return int(current_count) >= settings.RATE_LIMIT_REQUESTS
|
||||
# Filter requests from last minute
|
||||
recent_requests = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit check failed: {e}")
|
||||
return False
|
||||
|
||||
async def _update_rate_limit(client_id: str):
|
||||
"""Update rate limit counter"""
|
||||
try:
|
||||
key = f"rate_limit:{client_id}"
|
||||
# Update the list
|
||||
self.requests[client_id] = recent_requests
|
||||
|
||||
# Increment counter
|
||||
current_count = await redis_client.incr(key)
|
||||
# Check if limit exceeded
|
||||
return len(recent_requests) >= self.calls_per_minute
|
||||
|
||||
def _record_request(self, client_id: str):
|
||||
"""Record a request for rate limiting"""
|
||||
now = time.time()
|
||||
|
||||
# Set TTL on first request
|
||||
if current_count == 1:
|
||||
await redis_client.expire(key, settings.RATE_LIMIT_WINDOW)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit update failed: {e}")
|
||||
if client_id not in self.requests:
|
||||
self.requests[client_id] = []
|
||||
|
||||
self.requests[client_id].append(now)
|
||||
|
||||
# Keep only last minute of requests
|
||||
minute_ago = now - 60
|
||||
self.requests[client_id] = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user