Add improvements 2
This commit is contained in:
345
gateway/app/core/header_manager.py
Normal file
345
gateway/app/core/header_manager.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Unified Header Management System for API Gateway
|
||||
Centralized header injection, forwarding, and validation
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HeaderManager:
|
||||
"""
|
||||
Centralized header management for consistent header handling across gateway
|
||||
"""
|
||||
|
||||
# Standard header names (lowercase for consistency)
|
||||
STANDARD_HEADERS = {
|
||||
'user_id': 'x-user-id',
|
||||
'user_email': 'x-user-email',
|
||||
'user_role': 'x-user-role',
|
||||
'user_type': 'x-user-type',
|
||||
'service_name': 'x-service-name',
|
||||
'tenant_id': 'x-tenant-id',
|
||||
'subscription_tier': 'x-subscription-tier',
|
||||
'subscription_status': 'x-subscription-status',
|
||||
'is_demo': 'x-is-demo',
|
||||
'demo_session_id': 'x-demo-session-id',
|
||||
'demo_account_type': 'x-demo-account-type',
|
||||
'tenant_access_type': 'x-tenant-access-type',
|
||||
'can_view_children': 'x-can-view-children',
|
||||
'parent_tenant_id': 'x-parent-tenant-id',
|
||||
'forwarded_by': 'x-forwarded-by',
|
||||
'request_id': 'x-request-id'
|
||||
}
|
||||
|
||||
# Headers that should be sanitized/removed from incoming requests
|
||||
SANITIZED_HEADERS = [
|
||||
'x-subscription-',
|
||||
'x-user-',
|
||||
'x-tenant-',
|
||||
'x-demo-',
|
||||
'x-forwarded-by'
|
||||
]
|
||||
|
||||
# Headers that should be forwarded to downstream services
|
||||
FORWARDABLE_HEADERS = [
|
||||
'authorization',
|
||||
'content-type',
|
||||
'accept',
|
||||
'accept-language',
|
||||
'user-agent',
|
||||
'x-internal-service' # Required for internal service-to-service ML/alert triggers
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize header manager"""
|
||||
if not self._initialized:
|
||||
logger.info("HeaderManager initialized")
|
||||
self._initialized = True
|
||||
|
||||
def sanitize_incoming_headers(self, request: Request) -> None:
|
||||
"""
|
||||
Remove sensitive headers from incoming request to prevent spoofing
|
||||
"""
|
||||
if not hasattr(request.headers, '_list'):
|
||||
return
|
||||
|
||||
# Filter out headers that start with sanitized prefixes
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not any(k.decode().lower().startswith(prefix.lower())
|
||||
for prefix in self.SANITIZED_HEADERS)
|
||||
]
|
||||
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
logger.debug("Sanitized incoming headers")
|
||||
|
||||
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
|
||||
tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Inject standardized context headers into request
|
||||
Returns dict of injected headers for reference
|
||||
"""
|
||||
injected_headers = {}
|
||||
|
||||
# Ensure headers list exists
|
||||
if not hasattr(request.headers, '_list'):
|
||||
request.headers.__dict__["_list"] = []
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# User context headers
|
||||
if user_context.get('user_id'):
|
||||
header_name = self.STANDARD_HEADERS['user_id']
|
||||
header_value = str(user_context['user_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('email'):
|
||||
header_name = self.STANDARD_HEADERS['user_email']
|
||||
header_value = str(user_context['email'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('role'):
|
||||
header_name = self.STANDARD_HEADERS['user_role']
|
||||
header_value = str(user_context['role'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# User type (service vs regular user)
|
||||
if user_context.get('type'):
|
||||
header_name = self.STANDARD_HEADERS['user_type']
|
||||
header_value = str(user_context['type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Service name for service tokens
|
||||
if user_context.get('service'):
|
||||
header_name = self.STANDARD_HEADERS['service_name']
|
||||
header_value = str(user_context['service'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Tenant context
|
||||
if tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['tenant_id']
|
||||
header_value = str(tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Subscription context
|
||||
if user_context.get('subscription_tier'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_tier']
|
||||
header_value = str(user_context['subscription_tier'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('subscription_status'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_status']
|
||||
header_value = str(user_context['subscription_status'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Demo session context
|
||||
is_demo = user_context.get('is_demo', False)
|
||||
if is_demo:
|
||||
header_name = self.STANDARD_HEADERS['is_demo']
|
||||
header_value = "true"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_session_id'):
|
||||
header_name = self.STANDARD_HEADERS['demo_session_id']
|
||||
header_value = str(user_context['demo_session_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_account_type'):
|
||||
header_name = self.STANDARD_HEADERS['demo_account_type']
|
||||
header_value = str(user_context['demo_account_type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Hierarchical access context
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
header_name = self.STANDARD_HEADERS['tenant_access_type']
|
||||
header_value = str(tenant_access_type)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
header_name = self.STANDARD_HEADERS['can_view_children']
|
||||
header_value = str(can_view_children).lower()
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Parent tenant ID if hierarchical access
|
||||
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
|
||||
if parent_tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Gateway identification
|
||||
header_name = self.STANDARD_HEADERS['forwarded_by']
|
||||
header_value = "bakery-gateway"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Request ID if available
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
if request_id:
|
||||
header_name = self.STANDARD_HEADERS['request_id']
|
||||
header_value = str(request_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
logger.info("🔧 Injected context headers",
|
||||
user_id=user_context.get('user_id'),
|
||||
user_type=user_context.get('type', ''),
|
||||
service_name=user_context.get('service', ''),
|
||||
role=user_context.get('role', ''),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=is_demo,
|
||||
demo_session_id=user_context.get('demo_session_id', ''),
|
||||
path=request.url.path)
|
||||
|
||||
return injected_headers
|
||||
|
||||
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Safely add header to request
|
||||
"""
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name}: {e}")
|
||||
|
||||
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
|
||||
"""
|
||||
Get headers that should be forwarded to downstream services
|
||||
Includes both original request headers and injected context headers
|
||||
"""
|
||||
forwardable_headers = {}
|
||||
|
||||
# Add forwardable original headers
|
||||
for header_name in self.FORWARDABLE_HEADERS:
|
||||
header_value = request.headers.get(header_name)
|
||||
if header_value:
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add injected context headers from request.state
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
for header_name, header_value in request.state.injected_headers.items():
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add authorization header if present
|
||||
auth_header = request.headers.get('authorization')
|
||||
if auth_header:
|
||||
forwardable_headers['authorization'] = auth_header
|
||||
|
||||
return forwardable_headers
|
||||
|
||||
def get_all_headers_for_proxy(self, request: Request,
|
||||
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get complete set of headers for proxying to downstream services
|
||||
"""
|
||||
headers = self.get_forwardable_headers(request)
|
||||
|
||||
# Add any additional headers
|
||||
if additional_headers:
|
||||
headers.update(additional_headers)
|
||||
|
||||
# Remove host header as it will be set by httpx
|
||||
headers.pop('host', None)
|
||||
|
||||
return headers
|
||||
|
||||
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
|
||||
"""
|
||||
Validate that required headers are present
|
||||
"""
|
||||
missing_headers = []
|
||||
|
||||
for header_name in required_headers:
|
||||
# Check in injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
continue
|
||||
|
||||
# Check in request headers
|
||||
if request.headers.get(header_name):
|
||||
continue
|
||||
|
||||
missing_headers.append(header_name)
|
||||
|
||||
if missing_headers:
|
||||
logger.warning(f"Missing required headers: {missing_headers}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_header_value(self, request: Request, header_name: str,
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Get header value from either injected headers or request headers
|
||||
"""
|
||||
# Check injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
return request.state.injected_headers[header_name]
|
||||
|
||||
# Check request headers
|
||||
return request.headers.get(header_name, default)
|
||||
|
||||
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Allow middleware to add headers to the unified header system
|
||||
This ensures all headers are available for proxying
|
||||
"""
|
||||
# Ensure injected_headers exists
|
||||
if not hasattr(request.state, 'injected_headers'):
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# Add header to injected_headers
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Also add to actual request headers for compatibility
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
|
||||
|
||||
logger.debug(f"Middleware added header: {header_name} = {header_value}")
|
||||
|
||||
|
||||
# Global instance for easy access
|
||||
header_manager = HeaderManager()
|
||||
@@ -16,6 +16,7 @@ from shared.redis_utils import initialize_redis, close_redis, get_redis_client
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.middleware.request_id import RequestIDMiddleware
|
||||
from app.middleware.auth import AuthMiddleware
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
@@ -50,6 +51,10 @@ class GatewayService(StandardFastAPIService):
|
||||
"""Custom startup logic for Gateway"""
|
||||
global redis_client
|
||||
|
||||
# Initialize HeaderManager
|
||||
header_manager.initialize()
|
||||
logger.info("HeaderManager initialized")
|
||||
|
||||
# Initialize Redis
|
||||
try:
|
||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||
|
||||
@@ -14,6 +14,7 @@ import httpx
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||
|
||||
@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# SECURITY: Remove any incoming x-subscription-* headers
|
||||
# These will be re-injected from verified JWT only
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not k.decode().lower().startswith('x-subscription-')
|
||||
and not k.decode().lower().startswith('x-user-')
|
||||
and not k.decode().lower().startswith('x-tenant-')
|
||||
]
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
# SECURITY: Remove any incoming sensitive headers using HeaderManager
|
||||
header_manager.sanitize_incoming_headers(request)
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services
|
||||
ENHANCED: Added logging to verify header injection
|
||||
Inject user and tenant context headers for downstream services using unified HeaderManager
|
||||
"""
|
||||
# Enhanced logging for debugging
|
||||
logger.info(
|
||||
"🔧 Injecting context headers",
|
||||
user_id=user_context.get("user_id"),
|
||||
user_type=user_context.get("type", ""),
|
||||
service_name=user_context.get("service", ""),
|
||||
role=user_context.get("role", ""),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=user_context.get("is_demo", False),
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
# Add user context headers
|
||||
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
|
||||
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
|
||||
# Use unified HeaderManager for consistent header injection
|
||||
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {
|
||||
"x-user-id": user_context["user_id"],
|
||||
"x-user-email": user_context["email"],
|
||||
"x-user-role": user_context.get("role", "user")
|
||||
}
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-id", user_context["user_id"].encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-email", user_context["email"].encode()
|
||||
))
|
||||
|
||||
user_role = user_context.get("role", "user")
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-role", user_role.encode()
|
||||
))
|
||||
|
||||
user_type = user_context.get("type", "")
|
||||
if user_type:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-type", user_type.encode()
|
||||
))
|
||||
|
||||
service_name = user_context.get("service", "")
|
||||
if service_name:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-service-name", service_name.encode()
|
||||
))
|
||||
|
||||
# Add tenant context if available
|
||||
if tenant_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-id", tenant_id.encode()
|
||||
))
|
||||
|
||||
# Add subscription tier if available
|
||||
subscription_tier = user_context.get("subscription_tier", "")
|
||||
if subscription_tier:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-subscription-tier", subscription_tier.encode()
|
||||
))
|
||||
|
||||
# Add is_demo flag for demo sessions
|
||||
is_demo = user_context.get("is_demo", False)
|
||||
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
|
||||
if is_demo:
|
||||
logger.info(f"🎭 Adding demo session headers",
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
demo_account_type=user_context.get("demo_account_type", ""),
|
||||
path=request.url.path)
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-is-demo", b"true"
|
||||
))
|
||||
else:
|
||||
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
|
||||
|
||||
# Add demo session context headers for backend services
|
||||
demo_session_id = user_context.get("demo_session_id", "")
|
||||
if demo_session_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-demo-session-id", demo_session_id.encode()
|
||||
))
|
||||
|
||||
demo_account_type = user_context.get("demo_account_type", "")
|
||||
if demo_account_type:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-demo-account-type", demo_account_type.encode()
|
||||
))
|
||||
|
||||
# Add hierarchical access headers if tenant context exists
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-access-type", tenant_access_type.encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-can-view-children", str(can_view_children).encode()
|
||||
))
|
||||
|
||||
# If this is hierarchical access, include parent tenant ID
|
||||
# Get parent tenant ID from the auth service if available
|
||||
try:
|
||||
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
if parent_tenant_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-parent-tenant-id", parent_tenant_id.encode()
|
||||
))
|
||||
# Add parent tenant ID using HeaderManager for consistency
|
||||
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
header_manager.add_header_for_middleware(request, header_name, header_value)
|
||||
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||
pass
|
||||
|
||||
# Add gateway identification
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-forwarded-by", b"bakery-gateway"
|
||||
))
|
||||
|
||||
return injected_headers
|
||||
|
||||
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@@ -45,8 +45,17 @@ class APIRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Get subscription tier
|
||||
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
||||
# Get subscription tier from headers (added by AuthMiddleware)
|
||||
subscription_tier = request.headers.get("x-subscription-tier")
|
||||
|
||||
if not subscription_tier:
|
||||
# Fallback: get from request state if headers not available
|
||||
subscription_tier = getattr(request.state, "subscription_tier", None)
|
||||
|
||||
if not subscription_tier:
|
||||
# Final fallback: get from tenant service (should rarely happen)
|
||||
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
||||
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
|
||||
|
||||
# Get quota limit for tier
|
||||
quota_limit = self._get_quota_limit(subscription_tier)
|
||||
|
||||
@@ -9,6 +9,8 @@ from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -40,11 +42,9 @@ class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
# Bind request ID to structured logger context
|
||||
logger_ctx = logger.bind(request_id=request_id)
|
||||
|
||||
# Inject request ID header for downstream services
|
||||
# This is done by modifying the headers that will be forwarded
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-request-id", request_id.encode()
|
||||
))
|
||||
# Inject request ID header for downstream services using HeaderManager
|
||||
# Note: This runs early in middleware chain, so we use add_header_for_middleware
|
||||
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
|
||||
|
||||
# Log request start
|
||||
logger_ctx.info(
|
||||
|
||||
@@ -15,6 +15,7 @@ import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||
|
||||
logger = structlog.get_logger()
|
||||
@@ -178,7 +179,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
r'/api/v1/subscriptions/.*', # Subscription management itself
|
||||
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
||||
r'/docs.*',
|
||||
r'/openapi\.json'
|
||||
r'/openapi\.json',
|
||||
# Training monitoring endpoints (WebSocket and status checks)
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
|
||||
]
|
||||
|
||||
# Skip OPTIONS requests (CORS preflight)
|
||||
@@ -275,21 +279,11 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Use the same authentication pattern as gateway routes for fallback
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Use unified HeaderManager for consistent header handling
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Extract user_id for logging (fallback path)
|
||||
user_id = 'unknown'
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
user_id = str(user.get('user_id', 'unknown'))
|
||||
headers["x-user-id"] = user_id
|
||||
headers["x-user-email"] = str(user.get('email', ''))
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
|
||||
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
|
||||
|
||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||
timeout_config = httpx.Timeout(
|
||||
|
||||
@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
@@ -136,107 +137,32 @@ class AuthProxy:
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
# Convert headers to dict - get ALL headers including those added by middleware
|
||||
# Middleware adds headers to _list, so we need to read from there
|
||||
logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
|
||||
logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
|
||||
logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
|
||||
|
||||
if hasattr(headers, '_list'):
|
||||
logger.debug(f"DEBUG: Entering _list branch")
|
||||
logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
|
||||
|
||||
# Debug: Show first few headers in the list
|
||||
debug_headers = []
|
||||
for i, (k, v) in enumerate(all_headers_list):
|
||||
if i < 5: # Show first 5 headers for debugging
|
||||
"""Prepare headers for forwarding using unified HeaderManager"""
|
||||
# Use unified HeaderManager to get all headers
|
||||
if request:
|
||||
all_headers = header_manager.get_all_headers_for_proxy(request)
|
||||
logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
|
||||
else:
|
||||
# Fallback: convert headers to dict manually
|
||||
all_headers = {}
|
||||
if hasattr(headers, '_list'):
|
||||
for k, v in headers.__dict__.get('_list', []):
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
debug_headers.append(f"{key}: {value}")
|
||||
logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Debug: Show if x-user-id and x-is-demo are in the dict
|
||||
logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
|
||||
logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...") # Show first 10 keys
|
||||
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
logger.debug(f"DEBUG: Entering raw branch")
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# Fallback to raw headers if _list not available
|
||||
all_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
}
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
else:
|
||||
# Handle case where headers is already a dict
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
|
||||
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
filtered_headers['X-Gateway-Version'] = '1.0.0'
|
||||
|
||||
return filtered_headers
|
||||
all_headers[key] = value
|
||||
elif hasattr(headers, 'raw'):
|
||||
for k, v in headers.raw:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
else:
|
||||
# Headers is already a dict
|
||||
all_headers = dict(headers)
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
return all_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Prepare response headers"""
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -29,12 +30,8 @@ async def proxy_demo_service(path: str, request: Request):
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Forward headers (excluding host)
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in ["host", "content-length"]
|
||||
}
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
@@ -26,12 +27,8 @@ async def proxy_geocoding(request: Request, path: str):
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Forward headers (excluding host)
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in ["host", "content-length"]
|
||||
}
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Make the proxied request
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
@@ -44,12 +45,8 @@ async def proxy_poi_context(request: Request, path: str):
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Copy headers (exclude host and content-length as they'll be set by httpx)
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in ["host", "content-length"]
|
||||
}
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Make the request to the external service
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@@ -45,9 +46,8 @@ async def _proxy_to_pos_service(request: Request, target_path: str):
|
||||
try:
|
||||
url = f"{settings.POS_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@@ -98,29 +99,13 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Forward headers and add user/tenant context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
headers["x-user-id"] = str(user.get('user_id', ''))
|
||||
headers["x-user-email"] = str(user.get('email', ''))
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
|
||||
|
||||
# Add subscription context headers
|
||||
if user.get('subscription_tier'):
|
||||
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
|
||||
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
|
||||
|
||||
if user.get('subscription_status'):
|
||||
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
|
||||
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
|
||||
|
||||
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Debug logging
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
if user_context:
|
||||
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
logger.warning(f"No user context available when forwarding subscription request to {url}")
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@@ -715,36 +716,18 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Forward headers and add user/tenant context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Add tenant ID header if provided
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Add tenant ID header if provided (override if needed)
|
||||
if tenant_id:
|
||||
headers["X-Tenant-ID"] = tenant_id
|
||||
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
headers["x-user-id"] = str(user.get('user_id', ''))
|
||||
headers["x-user-email"] = str(user.get('email', ''))
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
|
||||
|
||||
# Add subscription context headers
|
||||
if user.get('subscription_tier'):
|
||||
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
|
||||
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
|
||||
|
||||
if user.get('subscription_status'):
|
||||
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
|
||||
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
|
||||
headers["x-tenant-id"] = tenant_id
|
||||
|
||||
# Debug logging
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
if user_context:
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
# Debug logging when no user context available
|
||||
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
|
||||
|
||||
# Get request body if present
|
||||
@@ -782,9 +765,10 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
|
||||
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
|
||||
|
||||
# Remove content-type from headers - httpx will set it with new boundary
|
||||
headers.pop("content-type", None)
|
||||
headers.pop("content-length", None)
|
||||
# For multipart requests, we need to get fresh headers since httpx will set content-type
|
||||
# Get all headers again to ensure we have the complete set
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
# httpx will automatically set content-type for multipart, so we don't need to remove it
|
||||
else:
|
||||
# For other content types, use body as before
|
||||
body = await request.body()
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
@@ -136,64 +137,28 @@ class UserProxy:
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
# Convert headers to dict if it's a Headers object
|
||||
# This ensures we get ALL headers including those added by middleware
|
||||
if hasattr(headers, '_list'):
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# FastAPI/Starlette Headers object - use raw to get all headers
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
"""Prepare headers for forwarding using unified HeaderManager"""
|
||||
# Use unified HeaderManager to get all headers
|
||||
if request:
|
||||
all_headers = header_manager.get_all_headers_for_proxy(request)
|
||||
else:
|
||||
# Already a dict
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
filtered_headers['X-Gateway-Version'] = '1.0.0'
|
||||
|
||||
return filtered_headers
|
||||
# Fallback: convert headers to dict manually
|
||||
all_headers = {}
|
||||
if hasattr(headers, '_list'):
|
||||
for k, v in headers.__dict__.get('_list', []):
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
elif hasattr(headers, 'raw'):
|
||||
for k, v in headers.raw:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
else:
|
||||
# Headers is already a dict
|
||||
all_headers = dict(headers)
|
||||
|
||||
return all_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Prepare response headers"""
|
||||
|
||||
Reference in New Issue
Block a user