Files
bakery-ia/gateway/app/core/header_manager.py
2026-01-15 20:45:49 +01:00

346 lines
14 KiB
Python

"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""
import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List
# Module-level structlog logger used by HeaderManager for its diagnostics
logger = structlog.get_logger()
class HeaderManager:
    """
    Centralized header management for consistent header handling across gateway.

    Responsibilities:
      * strip client-supplied context headers (``sanitize_incoming_headers``)
      * inject gateway-computed context headers (``inject_context_headers``)
      * build the header dict sent to downstream services
        (``get_forwardable_headers`` / ``get_all_headers_for_proxy``)
      * look up / validate headers across injected and request headers

    Injected headers are kept in two places: the request's raw header list
    (so ``request.headers`` sees them) and ``request.state.injected_headers``
    (so other middleware and the proxy layer can read them).
    """

    # Standard header names (lowercase for consistency)
    STANDARD_HEADERS = {
        'user_id': 'x-user-id',
        'user_email': 'x-user-email',
        'user_role': 'x-user-role',
        'user_type': 'x-user-type',
        'service_name': 'x-service-name',
        'tenant_id': 'x-tenant-id',
        'subscription_tier': 'x-subscription-tier',
        'subscription_status': 'x-subscription-status',
        'is_demo': 'x-is-demo',
        'demo_session_id': 'x-demo-session-id',
        'demo_account_type': 'x-demo-account-type',
        'tenant_access_type': 'x-tenant-access-type',
        'can_view_children': 'x-can-view-children',
        'parent_tenant_id': 'x-parent-tenant-id',
        'forwarded_by': 'x-forwarded-by',
        'request_id': 'x-request-id'
    }

    # Header-name prefixes that are removed from incoming requests.
    # Matching is case-insensitive and prefix-based (see sanitize_incoming_headers).
    SANITIZED_HEADERS = [
        'x-subscription-',
        'x-user-',
        'x-tenant-',
        'x-demo-',
        'x-forwarded-by',
        # Gateway-injected context headers that none of the prefixes above
        # cover; without these a client could supply them itself.
        'x-service-name',
        'x-is-demo',
        'x-can-view-children',
        'x-parent-tenant-id',
    ]

    # Headers that should be forwarded to downstream services
    # NOTE(review): 'x-internal-service' is copied from the incoming request and
    # is not sanitized; confirm downstream services authenticate it separately.
    FORWARDABLE_HEADERS = [
        'authorization',
        'content-type',
        'accept',
        'accept-language',
        'user-agent',
        'x-internal-service',  # Required for internal service-to-service ML/alert triggers
        'stripe-signature'  # Required for Stripe webhook signature verification
    ]

    # (user_context key, STANDARD_HEADERS key) pairs; each is injected only
    # when the user_context value is truthy. Tuple order is injection order.
    _IDENTITY_CONTEXT_HEADERS = (
        ('user_id', 'user_id'),
        ('email', 'user_email'),
        ('role', 'user_role'),
        ('type', 'user_type'),  # service vs regular user
        ('service', 'service_name'),  # service name for service tokens
    )
    _SUBSCRIPTION_CONTEXT_HEADERS = (
        ('subscription_tier', 'subscription_tier'),
        ('subscription_status', 'subscription_status'),
    )
    _DEMO_CONTEXT_HEADERS = (
        ('demo_session_id', 'demo_session_id'),
        ('demo_account_type', 'demo_account_type'),
    )

    def __init__(self):
        self._initialized = False

    def initialize(self):
        """Initialize header manager (idempotent; logs only on the first call)."""
        if not self._initialized:
            logger.info("HeaderManager initialized")
            self._initialized = True

    def sanitize_incoming_headers(self, request: Request) -> None:
        """
        Remove sensitive headers from incoming request to prevent spoofing.

        Every raw header whose lowercased name starts with one of
        ``SANITIZED_HEADERS`` is dropped. Does nothing when the headers object
        has no ``_list`` attribute.
        """
        if not hasattr(request.headers, '_list'):
            return

        # Compare on bytes: raw header names are bytes, and this avoids
        # decoding (and possibly failing on) every incoming header name.
        blocked_prefixes = tuple(
            prefix.lower().encode('latin-1') for prefix in self.SANITIZED_HEADERS
        )
        kept = [
            (name, value) for name, value in request.headers.raw
            if not name.lower().startswith(blocked_prefixes)
        ]

        # Mutate the existing list in place instead of rebinding it, so any
        # other holder of the same list object (Starlette aliases it to
        # scope["headers"] -- NOTE(review): confirm for the pinned version)
        # also sees the sanitized headers.
        current = request.headers.__dict__.get("_list")
        if isinstance(current, list):
            current[:] = kept
        else:
            request.headers.__dict__["_list"] = kept
        logger.debug("Sanitized incoming headers")

    def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
                               tenant_id: Optional[str] = None) -> Dict[str, str]:
        """
        Inject standardized context headers into request.

        Args:
            request: Request whose raw headers and ``state`` are updated.
            user_context: Mapping read for user, subscription and demo values;
                falsy values are skipped.
            tenant_id: When truthy, tenant and hierarchical-access headers are
                injected as well.

        Returns:
            Dict of injected header name -> value, for reference. This is a
            separate dict from ``request.state.injected_headers``.
        """
        injected_headers: Dict[str, str] = {}

        # Ensure headers list exists
        if not hasattr(request.headers, '_list'):
            request.headers.__dict__["_list"] = []

        # Store headers in request.state for cross-middleware access
        # (reset on every call, so earlier injected values are discarded)
        request.state.injected_headers = {}

        # User context headers
        for context_key, header_key in self._IDENTITY_CONTEXT_HEADERS:
            if user_context.get(context_key):
                self._inject(request, injected_headers, header_key,
                             str(user_context[context_key]))

        # Tenant context
        if tenant_id:
            self._inject(request, injected_headers, 'tenant_id', str(tenant_id))

        # Subscription context
        for context_key, header_key in self._SUBSCRIPTION_CONTEXT_HEADERS:
            if user_context.get(context_key):
                self._inject(request, injected_headers, header_key,
                             str(user_context[context_key]))

        # Demo session context
        is_demo = user_context.get('is_demo', False)
        if is_demo:
            self._inject(request, injected_headers, 'is_demo', "true")
        for context_key, header_key in self._DEMO_CONTEXT_HEADERS:
            if user_context.get(context_key):
                self._inject(request, injected_headers, header_key,
                             str(user_context[context_key]))

        # Hierarchical access context (values come from request.state, with
        # 'direct' / False / None as the fallbacks when unset)
        if tenant_id:
            tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
            can_view_children = getattr(request.state, 'can_view_children', False)
            self._inject(request, injected_headers, 'tenant_access_type',
                         str(tenant_access_type))
            self._inject(request, injected_headers, 'can_view_children',
                         str(can_view_children).lower())

            # Parent tenant ID if hierarchical access
            parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
            if parent_tenant_id:
                self._inject(request, injected_headers, 'parent_tenant_id',
                             str(parent_tenant_id))

        # Gateway identification
        self._inject(request, injected_headers, 'forwarded_by', "bakery-gateway")

        # Request ID if available
        request_id = getattr(request.state, 'request_id', None)
        if request_id:
            self._inject(request, injected_headers, 'request_id', str(request_id))

        logger.info("🔧 Injected context headers",
                    user_id=user_context.get('user_id'),
                    user_type=user_context.get('type', ''),
                    service_name=user_context.get('service', ''),
                    role=user_context.get('role', ''),
                    tenant_id=tenant_id,
                    is_demo=is_demo,
                    demo_session_id=user_context.get('demo_session_id', ''),
                    path=request.url.path)
        return injected_headers

    def _inject(self, request: Request, injected: Dict[str, str],
                header_key: str, header_value: str) -> None:
        """Record one STANDARD_HEADERS entry in the raw headers, ``injected`` and request.state."""
        header_name = self.STANDARD_HEADERS[header_key]
        self._add_header(request, header_name, header_value)
        injected[header_name] = header_value
        request.state.injected_headers[header_name] = header_value

    def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Safely add header to request.

        Appends to the raw header list; failures are logged, never raised
        (deliberate best-effort).
        """
        try:
            # Names are lowercased because lookups on the raw list compare
            # against lowercased names.
            # NOTE(review): values are encoded as UTF-8 here; confirm that
            # readers of the raw list decode them the same way.
            request.headers.__dict__["_list"].append(
                (header_name.lower().encode(), header_value.encode())
            )
        except Exception as e:
            logger.warning(f"Failed to add header {header_name}: {e}")

    def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
        """
        Get headers that should be forwarded to downstream services.

        Includes both original request headers (only those listed in
        ``FORWARDABLE_HEADERS``) and injected context headers. Injected
        headers win on name clashes, except ``authorization``, which always
        comes from the request when present.
        """
        forwardable_headers: Dict[str, str] = {}

        # Add forwardable original headers
        for header_name in self.FORWARDABLE_HEADERS:
            header_value = request.headers.get(header_name)
            if header_value:
                forwardable_headers[header_name] = header_value

        # Add injected context headers from request.state
        forwardable_headers.update(getattr(request.state, 'injected_headers', None) or {})

        # Re-apply the request's authorization header so an injected header of
        # the same name cannot replace it
        auth_header = request.headers.get('authorization')
        if auth_header:
            forwardable_headers['authorization'] = auth_header
        return forwardable_headers

    def get_all_headers_for_proxy(self, request: Request,
                                  additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """
        Get complete set of headers for proxying to downstream services.

        ``additional_headers`` override forwardable ones; a ``host`` entry is
        always removed from the result.
        """
        headers = self.get_forwardable_headers(request)

        # Add any additional headers
        if additional_headers:
            headers.update(additional_headers)

        # Remove host header as it will be set by httpx
        headers.pop('host', None)
        return headers

    @staticmethod
    def _lookup_injected(request: Request, header_name: str):
        """
        Find ``header_name`` in ``request.state.injected_headers``.

        Tries the exact key first, then a case-insensitive match, so lookups
        agree with the case-insensitive request-header fallback.

        Returns:
            ``(found, value)``; ``(False, None)`` when absent.
        """
        injected = getattr(request.state, 'injected_headers', None)
        if not injected:
            return False, None
        if header_name in injected:
            return True, injected[header_name]
        wanted = header_name.lower()
        for name, value in injected.items():
            if name.lower() == wanted:
                return True, value
        return False, None

    def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
        """
        Validate that required headers are present.

        A header counts as present when it is among the injected headers
        (any value) or has a truthy value in the request headers.

        Returns:
            True when nothing is missing; otherwise logs a warning listing the
            missing names and returns False.
        """
        missing_headers = []
        for header_name in required_headers:
            # Check in injected headers first
            found, _ = self._lookup_injected(request, header_name)
            if found:
                continue
            # Check in request headers
            if request.headers.get(header_name):
                continue
            missing_headers.append(header_name)

        if missing_headers:
            logger.warning(f"Missing required headers: {missing_headers}")
            return False
        return True

    def get_header_value(self, request: Request, header_name: str,
                         default: Optional[str] = None) -> Optional[str]:
        """
        Get header value from either injected headers or request headers.

        Injected headers take precedence; ``default`` is only used by the
        request-header fallback.
        """
        # Check injected headers first
        found, value = self._lookup_injected(request, header_name)
        if found:
            return value
        # Check request headers
        return request.headers.get(header_name, default)

    def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Allow middleware to add headers to the unified header system.

        This ensures all headers are available for proxying. The name is
        stored as given in ``request.state.injected_headers`` and lowercased
        in the raw header list.
        """
        # Ensure injected_headers exists
        if not hasattr(request.state, 'injected_headers'):
            request.state.injected_headers = {}

        # Add header to injected_headers
        request.state.injected_headers[header_name] = header_value

        # Also add to actual request headers for compatibility (best-effort).
        # The raw name is lowercased so lookups on the raw list can find a
        # mixed-case name.
        try:
            request.headers.__dict__["_list"].append(
                (header_name.lower().encode(), header_value.encode())
            )
        except Exception as e:
            logger.warning(f"Failed to add header {header_name} to request headers: {e}")
        logger.debug(f"Middleware added header: {header_name} = {header_value}")
# Global instance for easy access (created once, when this module is imported)
header_manager = HeaderManager()