# File: bakery-ia/gateway/app/core/header_manager.py
# (346 lines, 14 KiB, Python; last change recorded 2026-01-12 22:15:11 +01:00)
"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""
import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List
# Module-level structlog logger used by every HeaderManager method below.
logger = structlog.get_logger()
class HeaderManager:
    """
    Centralized header management for consistent header handling across gateway.

    Responsibilities visible in this class:
      * strip client-supplied context headers (``sanitize_incoming_headers``)
      * inject gateway-generated context headers (``inject_context_headers``,
        ``add_header_for_middleware``)
      * build the header dict sent to downstream services
        (``get_forwardable_headers``, ``get_all_headers_for_proxy``)
      * look up / validate headers (``get_header_value``,
        ``validate_required_headers``)

    Injected headers are tracked twice: in ``request.state.injected_headers``
    (a plain dict) and in the raw header list behind ``request.headers``.
    """

    # Standard header names (lowercase for consistency)
    STANDARD_HEADERS = {
        'user_id': 'x-user-id',
        'user_email': 'x-user-email',
        'user_role': 'x-user-role',
        'user_type': 'x-user-type',
        'service_name': 'x-service-name',
        'tenant_id': 'x-tenant-id',
        'subscription_tier': 'x-subscription-tier',
        'subscription_status': 'x-subscription-status',
        'is_demo': 'x-is-demo',
        'demo_session_id': 'x-demo-session-id',
        'demo_account_type': 'x-demo-account-type',
        'tenant_access_type': 'x-tenant-access-type',
        'can_view_children': 'x-can-view-children',
        'parent_tenant_id': 'x-parent-tenant-id',
        'forwarded_by': 'x-forwarded-by',
        'request_id': 'x-request-id'
    }

    # Headers that should be sanitized/removed from incoming requests.
    # Each entry is matched as a case-insensitive *prefix* of the header name.
    # NOTE(review): some injected names (x-is-demo, x-can-view-children,
    # x-parent-tenant-id, x-service-name) match none of these prefixes.
    # Injection replaces same-named raw headers, but confirm whether those
    # names should also be stripped here.
    SANITIZED_HEADERS = [
        'x-subscription-',
        'x-user-',
        'x-tenant-',
        'x-demo-',
        'x-forwarded-by'
    ]

    # Headers that should be forwarded to downstream services
    FORWARDABLE_HEADERS = [
        'authorization',
        'content-type',
        'accept',
        'accept-language',
        'user-agent',
        'x-internal-service',  # Required for internal service-to-service ML/alert triggers
        'stripe-signature',  # Required for Stripe webhook signature verification
    ]

    def __init__(self):
        self._initialized = False

    def initialize(self):
        """Initialize header manager (logs once; later calls are no-ops)."""
        if not self._initialized:
            logger.info("HeaderManager initialized")
            self._initialized = True

    def sanitize_incoming_headers(self, request: Request) -> None:
        """
        Remove sensitive headers from incoming request to prevent spoofing.

        Every raw header whose name starts with one of ``SANITIZED_HEADERS``
        (case-insensitive) is dropped. Does nothing when ``request.headers``
        has no ``_list`` attribute.
        """
        if not hasattr(request.headers, '_list'):
            return

        prefixes = tuple(prefix.lower() for prefix in self.SANITIZED_HEADERS)
        # latin-1 maps every byte to a character, so a malformed header name
        # cannot raise UnicodeDecodeError here.
        kept = [
            (name, value) for name, value in request.headers.raw
            if not name.decode('latin-1').lower().startswith(prefixes)
        ]

        raw_headers = request.headers.__dict__["_list"]
        if isinstance(raw_headers, list):
            # Mutate in place: Starlette's Headers shares this list object with
            # scope["headers"] (see Starlette datastructures), so rebinding
            # would leave the unsanitized headers visible through the scope.
            raw_headers[:] = kept
        else:
            request.headers.__dict__["_list"] = kept
        logger.debug("Sanitized incoming headers")

    def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
                               tenant_id: Optional[str] = None) -> Dict[str, str]:
        """
        Inject standardized context headers into request.

        Args:
            request: Request whose raw headers and ``state`` are updated.
            user_context: Mapping read for ``user_id``, ``email``, ``role``,
                ``type``, ``service``, ``subscription_tier``,
                ``subscription_status``, ``is_demo``, ``demo_session_id`` and
                ``demo_account_type``; falsy values are skipped.
            tenant_id: When truthy, tenant and hierarchical-access headers
                are injected as well.

        Returns:
            Dict of injected header name -> value, for reference.
        """
        injected_headers: Dict[str, str] = {}

        # Ensure headers list exists
        if not hasattr(request.headers, '_list'):
            request.headers.__dict__["_list"] = []

        # Store headers in request.state for cross-middleware access.
        # NOTE(review): this resets anything stored earlier through
        # add_header_for_middleware() -- confirm middleware ordering.
        request.state.injected_headers = {}

        def inject(key: str, value: str) -> None:
            """Record one standard header on the request, state and result."""
            header_name = self.STANDARD_HEADERS[key]
            self._add_header(request, header_name, value)
            injected_headers[header_name] = value
            request.state.injected_headers[header_name] = value

        def inject_from_context(context_key: str, key: str) -> None:
            """Inject ``user_context[context_key]`` when it is truthy."""
            value = user_context.get(context_key)
            if value:
                inject(key, str(value))

        # User context headers (user type = service vs regular user;
        # service name is set for service tokens)
        inject_from_context('user_id', 'user_id')
        inject_from_context('email', 'user_email')
        inject_from_context('role', 'user_role')
        inject_from_context('type', 'user_type')
        inject_from_context('service', 'service_name')

        # Tenant context
        if tenant_id:
            inject('tenant_id', str(tenant_id))

        # Subscription context
        inject_from_context('subscription_tier', 'subscription_tier')
        inject_from_context('subscription_status', 'subscription_status')

        # Demo session context
        is_demo = user_context.get('is_demo', False)
        if is_demo:
            inject('is_demo', "true")
        inject_from_context('demo_session_id', 'demo_session_id')
        inject_from_context('demo_account_type', 'demo_account_type')

        # Hierarchical access context (values read from request.state)
        if tenant_id:
            tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
            can_view_children = getattr(request.state, 'can_view_children', False)
            inject('tenant_access_type', str(tenant_access_type))
            inject('can_view_children', str(can_view_children).lower())

            # Parent tenant ID if hierarchical access
            parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
            if parent_tenant_id:
                inject('parent_tenant_id', str(parent_tenant_id))

        # Gateway identification
        inject('forwarded_by', "bakery-gateway")

        # Request ID if available
        request_id = getattr(request.state, 'request_id', None)
        if request_id:
            inject('request_id', str(request_id))

        logger.info("🔧 Injected context headers",
                    user_id=user_context.get('user_id'),
                    user_type=user_context.get('type', ''),
                    service_name=user_context.get('service', ''),
                    role=user_context.get('role', ''),
                    tenant_id=tenant_id,
                    is_demo=is_demo,
                    demo_session_id=user_context.get('demo_session_id', ''),
                    path=request.url.path)
        return injected_headers

    @staticmethod
    def _set_raw_header(request: Request, header_name: str, header_value: str) -> None:
        """
        Set ``header_name`` to ``header_value`` in the raw header list.

        Any existing entry with the same (case-insensitive) name is removed
        first, because header lookups return the first match and a stale or
        client-supplied value would otherwise shadow the new one. The name is
        stored lower-cased, as ASGI requires for raw header names.

        Raises whatever the encode/list operations raise; callers decide how
        to report it.
        """
        # Encode before touching the list so a failure leaves it unchanged.
        raw_name = header_name.lower().encode()
        raw_value = header_value.encode()
        raw_headers = request.headers.__dict__["_list"]
        raw_headers[:] = [
            (name, value) for name, value in raw_headers
            if name.lower() != raw_name
        ]
        raw_headers.append((raw_name, raw_value))

    def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Safely add header to request (best effort: failures are logged as a
        warning, never raised). Replaces any existing header of the same name.
        """
        try:
            self._set_raw_header(request, header_name, header_value)
        except Exception as e:
            logger.warning(f"Failed to add header {header_name}: {e}")

    def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
        """
        Get headers that should be forwarded to downstream services.

        Includes both original request headers (only those listed in
        ``FORWARDABLE_HEADERS``) and injected context headers.
        """
        forwardable_headers = {}

        # Add forwardable original headers
        for header_name in self.FORWARDABLE_HEADERS:
            header_value = request.headers.get(header_name)
            if header_value:
                forwardable_headers[header_name] = header_value

        # Add injected context headers from request.state
        if hasattr(request.state, 'injected_headers'):
            forwardable_headers.update(request.state.injected_headers)

        # Re-apply the request's own authorization header last so an injected
        # 'authorization' entry can never override it.
        auth_header = request.headers.get('authorization')
        if auth_header:
            forwardable_headers['authorization'] = auth_header

        return forwardable_headers

    def get_all_headers_for_proxy(self, request: Request,
                                  additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """
        Get complete set of headers for proxying to downstream services.

        ``additional_headers`` (if given) override the forwardable headers;
        a ``host`` key is always dropped from the result.
        """
        headers = self.get_forwardable_headers(request)

        # Add any additional headers
        if additional_headers:
            headers.update(additional_headers)

        # Remove host header as it will be set by httpx
        headers.pop('host', None)
        return headers

    def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
        """
        Validate that required headers are present.

        A header counts as present when it is a key of
        ``request.state.injected_headers`` or has a truthy value in
        ``request.headers``. Returns ``False`` (and logs a warning listing the
        missing names) if any are absent, otherwise ``True``.
        """
        missing_headers = []
        for header_name in required_headers:
            # Check in injected headers first
            if hasattr(request.state, 'injected_headers'):
                if header_name in request.state.injected_headers:
                    continue

            # Check in request headers
            if request.headers.get(header_name):
                continue

            missing_headers.append(header_name)

        if missing_headers:
            logger.warning(f"Missing required headers: {missing_headers}")
            return False
        return True

    def get_header_value(self, request: Request, header_name: str,
                         default: Optional[str] = None) -> Optional[str]:
        """
        Get header value from either injected headers or request headers.

        Injected headers take precedence; ``default`` is returned when the
        header is found in neither place.
        """
        # Check injected headers first
        if hasattr(request.state, 'injected_headers'):
            if header_name in request.state.injected_headers:
                return request.state.injected_headers[header_name]

        # Check request headers
        return request.headers.get(header_name, default)

    def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
        """
        Allow middleware to add headers to the unified header system.

        The header is stored in ``request.state.injected_headers`` (so it is
        available for proxying) and also written to the raw request headers,
        replacing any existing header of the same name. A failure to update
        the raw headers is logged as a warning, not raised.
        """
        # Ensure injected_headers exists
        if not hasattr(request.state, 'injected_headers'):
            request.state.injected_headers = {}

        # Add header to injected_headers
        request.state.injected_headers[header_name] = header_value

        # Also add to actual request headers for compatibility
        try:
            self._set_raw_header(request, header_name, header_value)
        except Exception as e:
            logger.warning(f"Failed to add header {header_name} to request headers: {e}")

        logger.debug(f"Middleware added header: {header_name} = {header_value}")
# Global instance for easy access.
# Created at import time; initialize() is not called here.
header_manager = HeaderManager()