"""
Unified Header Management System for API Gateway

Centralized header injection, forwarding, and validation.
"""

from typing import Any, Dict, List, Optional

import structlog
from fastapi import Request

# Module-level structlog logger used by every method in this file.
logger = structlog.get_logger()

class HeaderManager:
    """
    Centralized header management for consistent header handling across gateway.

    Responsibilities:

    * ``sanitize_incoming_headers`` strips client-supplied copies of the
      gateway's context headers so they cannot be spoofed.
    * ``inject_context_headers`` / ``add_header_for_middleware`` write trusted
      headers both into the request's raw header list and into
      ``request.state.injected_headers``.
    * ``get_forwardable_headers`` / ``get_all_headers_for_proxy`` build the
      header dict sent to downstream services.

    ``request.headers`` is read-only, so the raw ``(name, value)`` byte pairs
    are edited through the private ``_list`` attribute of the headers object.
    """

    # Standard header names (lowercase for consistency)
    STANDARD_HEADERS = {
        'user_id': 'x-user-id',
        'user_email': 'x-user-email',
        'user_role': 'x-user-role',
        'user_type': 'x-user-type',
        'service_name': 'x-service-name',
        'tenant_id': 'x-tenant-id',
        'subscription_tier': 'x-subscription-tier',
        'subscription_status': 'x-subscription-status',
        'is_demo': 'x-is-demo',
        'demo_session_id': 'x-demo-session-id',
        'demo_account_type': 'x-demo-account-type',
        'tenant_access_type': 'x-tenant-access-type',
        'can_view_children': 'x-can-view-children',
        'parent_tenant_id': 'x-parent-tenant-id',
        'forwarded_by': 'x-forwarded-by',
        'request_id': 'x-request-id'
    }

    # Prefixes of incoming headers removed before injection to prevent spoofing.
    # Keep in sync with STANDARD_HEADERS: every gateway-injected name except
    # x-request-id must match one of these. x-request-id is left alone on the
    # assumption that clients may supply it for tracing (TODO confirm).
    SANITIZED_HEADERS = [
        'x-subscription-',
        'x-user-',
        'x-tenant-',
        'x-demo-',
        'x-forwarded-by',
        'x-service-name',
        'x-is-demo',
        'x-can-view-children',
        'x-parent-tenant-id',
    ]

    # Original request headers that are forwarded to downstream services.
    FORWARDABLE_HEADERS = [
        'authorization',
        'content-type',
        'accept',
        'accept-language',
        'user-agent',
        # Required for internal service-to-service ML/alert triggers.
        # NOTE(review): this is copied verbatim from the client request and is
        # not sanitized; confirm downstream services do not trust it alone.
        'x-internal-service',
    ]

    def __init__(self):
        self._initialized = False

    def initialize(self):
        """Mark the manager initialized; only the first call logs, later calls are no-ops."""
        if not self._initialized:
            logger.info("HeaderManager initialized")
            self._initialized = True

    @staticmethod
    def _replace_raw_header(request: Request, header_name: str, header_value: str) -> None:
        """
        Make ``header_name`` carry exactly ``header_value`` in the raw header list.

        Existing entries with the same name (compared lowercased) are removed
        first: ``request.headers.get`` returns the first match, so a plain append
        would let an older value shadow the new one. The name is stored
        lowercased so lookups through ``request.headers`` (which lowercase the
        requested name) can find it.

        Raises whatever the underlying access raises (e.g. ``KeyError`` when the
        headers object has no ``_list``); callers decide how to report it.
        """
        raw = request.headers.__dict__["_list"]
        name = header_name.lower().encode()
        raw[:] = [(key, value) for key, value in raw if key.lower() != name]
        raw.append((name, header_value.encode()))

    def sanitize_incoming_headers(self, request: Request) -> None:
        """
        Remove client-supplied gateway context headers to prevent spoofing.

        Drops every raw header whose lowercased name starts with one of
        ``SANITIZED_HEADERS``. Does nothing if ``request.headers`` has no
        ``_list`` attribute.

        The list is filtered in place rather than rebound, so any other holder
        of the same list object also sees the result. NOTE(review): in Starlette
        this list is normally the same object as ``scope["headers"]``, so later
        ``Request`` objects built from the scope are sanitized too. Verify this
        against the pinned Starlette version.
        """
        if not hasattr(request.headers, '_list'):
            return

        # Match on bytes so attacker-controlled names are never decoded;
        # str.startswith/bytes.startswith accept a tuple of prefixes.
        prefixes = tuple(prefix.lower().encode('latin-1') for prefix in self.SANITIZED_HEADERS)
        raw = request.headers._list
        kept = [(key, value) for key, value in raw if not key.lower().startswith(prefixes)]

        if isinstance(raw, list):
            raw[:] = kept
        else:
            # Not a mutable list: fall back to rebinding the attribute.
            request.headers.__dict__["_list"] = kept

        logger.debug("Sanitized incoming headers")

    def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
                               tenant_id: Optional[str] = None) -> Dict[str, str]:
        """
        Inject standardized context headers into the request.

        Each header is written to the raw request headers (replacing any
        same-named entry) and recorded in ``request.state.injected_headers``.
        That dict is reset at the start of every call. NOTE(review): this also
        discards entries added earlier via ``add_header_for_middleware``;
        confirm the middleware ordering relies on that.

        Args:
            request: the incoming request. These ``request.state`` attributes are
                read when present: ``tenant_access_type``,
                ``can_view_children``, ``parent_tenant_id``, ``request_id``.
            user_context: authenticated context. The keys read are ``user_id``,
                ``email``, ``role``, ``type``, ``service``,
                ``subscription_tier``, ``subscription_status``, ``is_demo``,
                ``demo_session_id`` and ``demo_account_type``. Falsy values are
                skipped.
            tenant_id: resolved tenant. When truthy, the tenant and
                hierarchical-access headers are also injected.

        Returns:
            Dict of the headers injected by this call (name -> value).
        """
        injected_headers: Dict[str, str] = {}

        # Ensure headers list exists
        if not hasattr(request.headers, '_list'):
            request.headers.__dict__["_list"] = []

        # Store headers in request.state for cross-middleware access
        request.state.injected_headers = {}

        def _set(standard_key: str, value: Any) -> None:
            # Write one STANDARD_HEADERS entry to raw headers, result and state.
            header_name = self.STANDARD_HEADERS[standard_key]
            header_value = str(value)
            self._add_header(request, header_name, header_value)
            injected_headers[header_name] = header_value
            request.state.injected_headers[header_name] = header_value

        # User context: (user_context key, STANDARD_HEADERS key).
        # 'type' distinguishes service vs regular users; 'service' is set for
        # service tokens.
        for context_key, standard_key in (
            ('user_id', 'user_id'),
            ('email', 'user_email'),
            ('role', 'user_role'),
            ('type', 'user_type'),
            ('service', 'service_name'),
        ):
            if user_context.get(context_key):
                _set(standard_key, user_context[context_key])

        # Tenant context
        if tenant_id:
            _set('tenant_id', tenant_id)

        # Subscription context
        for key in ('subscription_tier', 'subscription_status'):
            if user_context.get(key):
                _set(key, user_context[key])

        # Demo session context
        is_demo = user_context.get('is_demo', False)
        if is_demo:
            _set('is_demo', "true")

        for key in ('demo_session_id', 'demo_account_type'):
            if user_context.get(key):
                _set(key, user_context[key])

        # Hierarchical access context (only with a resolved tenant)
        if tenant_id:
            _set('tenant_access_type', getattr(request.state, 'tenant_access_type', 'direct'))
            # Lowercased so Python booleans travel as "true"/"false".
            _set('can_view_children', str(getattr(request.state, 'can_view_children', False)).lower())

            parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
            if parent_tenant_id:
                _set('parent_tenant_id', parent_tenant_id)

        # Gateway identification
        _set('forwarded_by', "bakery-gateway")

        # Request ID if available
        request_id = getattr(request.state, 'request_id', None)
        if request_id:
            _set('request_id', request_id)

        logger.info("🔧 Injected context headers",
                    user_id=user_context.get('user_id'),
                    user_type=user_context.get('type', ''),
                    service_name=user_context.get('service', ''),
                    role=user_context.get('role', ''),
                    tenant_id=tenant_id,
                    is_demo=is_demo,
                    demo_session_id=user_context.get('demo_session_id', ''),
                    path=request.url.path)

        return injected_headers

    def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Best-effort: set a header in the raw request headers.

        Any same-named entry is replaced. Failures are logged as warnings and
        not raised.
        """
        try:
            self._replace_raw_header(request, header_name, header_value)
        except Exception as e:
            logger.warning(f"Failed to add header {header_name}: {e}")

    def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
        """
        Get headers that should be forwarded to downstream services.

        Merge order (later wins):

        1. whitelisted ``FORWARDABLE_HEADERS`` from the request;
        2. ``request.state.injected_headers``;
        3. ``authorization`` from ``request.headers``, when present.
        """
        forwardable_headers: Dict[str, str] = {}

        # Add forwardable original headers
        for header_name in self.FORWARDABLE_HEADERS:
            header_value = request.headers.get(header_name)
            if header_value:
                forwardable_headers[header_name] = header_value

        # Add injected context headers from request.state
        if hasattr(request.state, 'injected_headers'):
            forwardable_headers.update(request.state.injected_headers)

        # Re-apply Authorization so the request's value wins over a
        # same-named injected entry.
        auth_header = request.headers.get('authorization')
        if auth_header:
            forwardable_headers['authorization'] = auth_header

        return forwardable_headers

    def get_all_headers_for_proxy(self, request: Request,
                                  additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """
        Get the complete set of headers for proxying to downstream services.

        Starts from ``get_forwardable_headers``, then overlays
        ``additional_headers``, then drops a lowercase ``host`` key.
        """
        headers = self.get_forwardable_headers(request)

        # Add any additional headers
        if additional_headers:
            headers.update(additional_headers)

        # Remove host header as it will be set by httpx
        headers.pop('host', None)

        return headers

    def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
        """
        Check that every name in ``required_headers`` is available.

        A name counts as present if it is a key of
        ``request.state.injected_headers`` or has a non-empty value in
        ``request.headers``. Logs a warning listing the missing names and
        returns False if any are absent.
        """
        injected = getattr(request.state, 'injected_headers', {})
        missing_headers = [
            header_name for header_name in required_headers
            if header_name not in injected and not request.headers.get(header_name)
        ]

        if missing_headers:
            logger.warning(f"Missing required headers: {missing_headers}")
            return False

        return True

    def get_header_value(self, request: Request, header_name: str,
                         default: Optional[str] = None) -> Optional[str]:
        """
        Get a header value, preferring injected headers over request headers.

        The injected-headers lookup is an exact-key match. The fallback to
        ``request.headers.get`` returns ``default`` when the header is absent.
        """
        injected = getattr(request.state, 'injected_headers', {})
        if header_name in injected:
            return injected[header_name]

        return request.headers.get(header_name, default)

    def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Allow middleware to add a header to the unified header system.

        The header is recorded in ``request.state.injected_headers`` (so it is
        proxied) under ``header_name`` as given. It is also written,
        lowercased and replacing any same-named entry, into the raw request
        headers; a failure there is logged, not raised.
        """
        # Ensure injected_headers exists
        if not hasattr(request.state, 'injected_headers'):
            request.state.injected_headers = {}

        request.state.injected_headers[header_name] = header_value

        # Also add to actual request headers for compatibility
        try:
            self._replace_raw_header(request, header_name, header_value)
        except Exception as e:
            logger.warning(f"Failed to add header {header_name} to request headers: {e}")

        logger.debug(f"Middleware added header: {header_name} = {header_value}")

# Shared module-level instance for easy access; import this rather than
# constructing new managers. NOTE(review): initialize() is not called here —
# presumably invoked at application startup; confirm.
header_manager = HeaderManager()