Add improvements

This commit is contained in:
Urtzi Alfaro
2026-01-12 14:24:14 +01:00
parent 6037faaf8c
commit 230bbe6a19
61 changed files with 1668 additions and 894 deletions

View File

@@ -63,7 +63,9 @@ class AuthProxy:
target_url = f"{auth_url}/{path}"
# Prepare headers (remove hop-by-hop headers)
headers = self._prepare_headers(dict(request.headers))
# IMPORTANT: Use request.headers directly to get headers added by middleware
# Also check request.state for headers injected by middleware
headers = self._prepare_headers(request.headers, request)
# Get request body
body = await request.body()
@@ -133,7 +135,7 @@ class AuthProxy:
# Fall back to configured URL
return AUTH_SERVICE_URL
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
@@ -141,10 +143,94 @@ class AuthProxy:
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Convert headers to dict - get ALL headers including those added by middleware
# Middleware adds headers to _list, so we need to read from there
logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
if hasattr(headers, '_list'):
logger.debug(f"DEBUG: Entering _list branch")
logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
# Debug: Show first few headers in the list
debug_headers = []
for i, (k, v) in enumerate(all_headers_list):
if i < 5: # Show first 5 headers for debugging
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
debug_headers.append(f"{key}: {value}")
logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
# Convert to dict for easier processing
all_headers = {}
for k, v in all_headers_list:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Debug: Show if x-user-id and x-is-demo are in the dict
logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...") # Show first 10 keys
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
logger.debug(f"DEBUG: Entering raw branch")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# Fallback to raw headers if _list not available
all_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
}
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
else:
# Handle case where headers is already a dict
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'

View File

@@ -110,16 +110,16 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
else:
logger.warning(f"No user context available when forwarding subscription request to {url}")

View File

@@ -714,15 +714,15 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
try:
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Add tenant ID header if provided
if tenant_id:
headers["X-Tenant-ID"] = tenant_id
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
@@ -731,16 +731,16 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
# Debug logging
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
else:

View File

@@ -63,7 +63,9 @@ class UserProxy:
target_url = f"{auth_url}/api/v1/auth/{path}"
# Prepare headers (remove hop-by-hop headers)
headers = self._prepare_headers(dict(request.headers))
# IMPORTANT: Use request.headers directly to get headers added by middleware
# Also check request.state for headers injected by middleware
headers = self._prepare_headers(request.headers, request)
# Get request body
body = await request.body()
@@ -133,23 +135,64 @@ class UserProxy:
# Fall back to configured URL
return AUTH_SERVICE_URL
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Convert headers to dict if it's a Headers object
# This ensures we get ALL headers including those added by middleware
if hasattr(headers, '_list'):
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
# Convert to dict for easier processing
all_headers = {}
for k, v in all_headers_list:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# FastAPI/Starlette Headers object - use raw to get all headers
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
else:
# Already a dict
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
filtered_headers['X-Gateway-Version'] = '1.0.0'
return filtered_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]: