Add improvements
This commit is contained in:
@@ -63,7 +63,9 @@ class UserProxy:
|
||||
target_url = f"{auth_url}/api/v1/auth/{path}"
|
||||
|
||||
# Prepare headers (remove hop-by-hop headers)
|
||||
headers = self._prepare_headers(dict(request.headers))
|
||||
# IMPORTANT: Use request.headers directly to get headers added by middleware
|
||||
# Also check request.state for headers injected by middleware
|
||||
headers = self._prepare_headers(request.headers, request)
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
@@ -133,23 +135,64 @@ class UserProxy:
|
||||
# Fall back to configured URL
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
|
||||
# Convert headers to dict if it's a Headers object
|
||||
# This ensures we get ALL headers including those added by middleware
|
||||
if hasattr(headers, '_list'):
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# FastAPI/Starlette Headers object - use raw to get all headers
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
else:
|
||||
# Already a dict
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
filtered_headers['X-Gateway-Version'] = '1.0.0'
|
||||
|
||||
|
||||
return filtered_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
|
||||
Reference in New Issue
Block a user