Add improvements 2
This commit is contained in:
345
gateway/app/core/header_manager.py
Normal file
345
gateway/app/core/header_manager.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
Unified Header Management System for API Gateway
|
||||
Centralized header injection, forwarding, and validation
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HeaderManager:
|
||||
"""
|
||||
Centralized header management for consistent header handling across gateway
|
||||
"""
|
||||
|
||||
# Standard header names (lowercase for consistency)
|
||||
STANDARD_HEADERS = {
|
||||
'user_id': 'x-user-id',
|
||||
'user_email': 'x-user-email',
|
||||
'user_role': 'x-user-role',
|
||||
'user_type': 'x-user-type',
|
||||
'service_name': 'x-service-name',
|
||||
'tenant_id': 'x-tenant-id',
|
||||
'subscription_tier': 'x-subscription-tier',
|
||||
'subscription_status': 'x-subscription-status',
|
||||
'is_demo': 'x-is-demo',
|
||||
'demo_session_id': 'x-demo-session-id',
|
||||
'demo_account_type': 'x-demo-account-type',
|
||||
'tenant_access_type': 'x-tenant-access-type',
|
||||
'can_view_children': 'x-can-view-children',
|
||||
'parent_tenant_id': 'x-parent-tenant-id',
|
||||
'forwarded_by': 'x-forwarded-by',
|
||||
'request_id': 'x-request-id'
|
||||
}
|
||||
|
||||
# Headers that should be sanitized/removed from incoming requests
|
||||
SANITIZED_HEADERS = [
|
||||
'x-subscription-',
|
||||
'x-user-',
|
||||
'x-tenant-',
|
||||
'x-demo-',
|
||||
'x-forwarded-by'
|
||||
]
|
||||
|
||||
# Headers that should be forwarded to downstream services
|
||||
FORWARDABLE_HEADERS = [
|
||||
'authorization',
|
||||
'content-type',
|
||||
'accept',
|
||||
'accept-language',
|
||||
'user-agent',
|
||||
'x-internal-service' # Required for internal service-to-service ML/alert triggers
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize header manager"""
|
||||
if not self._initialized:
|
||||
logger.info("HeaderManager initialized")
|
||||
self._initialized = True
|
||||
|
||||
def sanitize_incoming_headers(self, request: Request) -> None:
|
||||
"""
|
||||
Remove sensitive headers from incoming request to prevent spoofing
|
||||
"""
|
||||
if not hasattr(request.headers, '_list'):
|
||||
return
|
||||
|
||||
# Filter out headers that start with sanitized prefixes
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not any(k.decode().lower().startswith(prefix.lower())
|
||||
for prefix in self.SANITIZED_HEADERS)
|
||||
]
|
||||
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
logger.debug("Sanitized incoming headers")
|
||||
|
||||
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
|
||||
tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Inject standardized context headers into request
|
||||
Returns dict of injected headers for reference
|
||||
"""
|
||||
injected_headers = {}
|
||||
|
||||
# Ensure headers list exists
|
||||
if not hasattr(request.headers, '_list'):
|
||||
request.headers.__dict__["_list"] = []
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# User context headers
|
||||
if user_context.get('user_id'):
|
||||
header_name = self.STANDARD_HEADERS['user_id']
|
||||
header_value = str(user_context['user_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('email'):
|
||||
header_name = self.STANDARD_HEADERS['user_email']
|
||||
header_value = str(user_context['email'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('role'):
|
||||
header_name = self.STANDARD_HEADERS['user_role']
|
||||
header_value = str(user_context['role'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# User type (service vs regular user)
|
||||
if user_context.get('type'):
|
||||
header_name = self.STANDARD_HEADERS['user_type']
|
||||
header_value = str(user_context['type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Service name for service tokens
|
||||
if user_context.get('service'):
|
||||
header_name = self.STANDARD_HEADERS['service_name']
|
||||
header_value = str(user_context['service'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Tenant context
|
||||
if tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['tenant_id']
|
||||
header_value = str(tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Subscription context
|
||||
if user_context.get('subscription_tier'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_tier']
|
||||
header_value = str(user_context['subscription_tier'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('subscription_status'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_status']
|
||||
header_value = str(user_context['subscription_status'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Demo session context
|
||||
is_demo = user_context.get('is_demo', False)
|
||||
if is_demo:
|
||||
header_name = self.STANDARD_HEADERS['is_demo']
|
||||
header_value = "true"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_session_id'):
|
||||
header_name = self.STANDARD_HEADERS['demo_session_id']
|
||||
header_value = str(user_context['demo_session_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_account_type'):
|
||||
header_name = self.STANDARD_HEADERS['demo_account_type']
|
||||
header_value = str(user_context['demo_account_type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Hierarchical access context
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
header_name = self.STANDARD_HEADERS['tenant_access_type']
|
||||
header_value = str(tenant_access_type)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
header_name = self.STANDARD_HEADERS['can_view_children']
|
||||
header_value = str(can_view_children).lower()
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Parent tenant ID if hierarchical access
|
||||
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
|
||||
if parent_tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Gateway identification
|
||||
header_name = self.STANDARD_HEADERS['forwarded_by']
|
||||
header_value = "bakery-gateway"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Request ID if available
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
if request_id:
|
||||
header_name = self.STANDARD_HEADERS['request_id']
|
||||
header_value = str(request_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
logger.info("🔧 Injected context headers",
|
||||
user_id=user_context.get('user_id'),
|
||||
user_type=user_context.get('type', ''),
|
||||
service_name=user_context.get('service', ''),
|
||||
role=user_context.get('role', ''),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=is_demo,
|
||||
demo_session_id=user_context.get('demo_session_id', ''),
|
||||
path=request.url.path)
|
||||
|
||||
return injected_headers
|
||||
|
||||
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Safely add header to request
|
||||
"""
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name}: {e}")
|
||||
|
||||
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
|
||||
"""
|
||||
Get headers that should be forwarded to downstream services
|
||||
Includes both original request headers and injected context headers
|
||||
"""
|
||||
forwardable_headers = {}
|
||||
|
||||
# Add forwardable original headers
|
||||
for header_name in self.FORWARDABLE_HEADERS:
|
||||
header_value = request.headers.get(header_name)
|
||||
if header_value:
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add injected context headers from request.state
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
for header_name, header_value in request.state.injected_headers.items():
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add authorization header if present
|
||||
auth_header = request.headers.get('authorization')
|
||||
if auth_header:
|
||||
forwardable_headers['authorization'] = auth_header
|
||||
|
||||
return forwardable_headers
|
||||
|
||||
def get_all_headers_for_proxy(self, request: Request,
|
||||
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get complete set of headers for proxying to downstream services
|
||||
"""
|
||||
headers = self.get_forwardable_headers(request)
|
||||
|
||||
# Add any additional headers
|
||||
if additional_headers:
|
||||
headers.update(additional_headers)
|
||||
|
||||
# Remove host header as it will be set by httpx
|
||||
headers.pop('host', None)
|
||||
|
||||
return headers
|
||||
|
||||
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
|
||||
"""
|
||||
Validate that required headers are present
|
||||
"""
|
||||
missing_headers = []
|
||||
|
||||
for header_name in required_headers:
|
||||
# Check in injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
continue
|
||||
|
||||
# Check in request headers
|
||||
if request.headers.get(header_name):
|
||||
continue
|
||||
|
||||
missing_headers.append(header_name)
|
||||
|
||||
if missing_headers:
|
||||
logger.warning(f"Missing required headers: {missing_headers}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_header_value(self, request: Request, header_name: str,
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Get header value from either injected headers or request headers
|
||||
"""
|
||||
# Check injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
return request.state.injected_headers[header_name]
|
||||
|
||||
# Check request headers
|
||||
return request.headers.get(header_name, default)
|
||||
|
||||
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Allow middleware to add headers to the unified header system
|
||||
This ensures all headers are available for proxying
|
||||
"""
|
||||
# Ensure injected_headers exists
|
||||
if not hasattr(request.state, 'injected_headers'):
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# Add header to injected_headers
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Also add to actual request headers for compatibility
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
|
||||
|
||||
logger.debug(f"Middleware added header: {header_name} = {header_value}")
|
||||
|
||||
|
||||
# Global instance for easy access
|
||||
header_manager = HeaderManager()
|
||||
Reference in New Issue
Block a user