""" Unified Header Management System for API Gateway Centralized header injection, forwarding, and validation """ import structlog from fastapi import Request from typing import Dict, Any, Optional, List logger = structlog.get_logger() class HeaderManager: """ Centralized header management for consistent header handling across gateway """ # Standard header names (lowercase for consistency) STANDARD_HEADERS = { 'user_id': 'x-user-id', 'user_email': 'x-user-email', 'user_role': 'x-user-role', 'user_type': 'x-user-type', 'service_name': 'x-service-name', 'tenant_id': 'x-tenant-id', 'subscription_tier': 'x-subscription-tier', 'subscription_status': 'x-subscription-status', 'is_demo': 'x-is-demo', 'demo_session_id': 'x-demo-session-id', 'demo_account_type': 'x-demo-account-type', 'tenant_access_type': 'x-tenant-access-type', 'can_view_children': 'x-can-view-children', 'parent_tenant_id': 'x-parent-tenant-id', 'forwarded_by': 'x-forwarded-by', 'request_id': 'x-request-id' } # Headers that should be sanitized/removed from incoming requests SANITIZED_HEADERS = [ 'x-subscription-', 'x-user-', 'x-tenant-', 'x-demo-', 'x-forwarded-by' ] # Headers that should be forwarded to downstream services FORWARDABLE_HEADERS = [ 'authorization', 'content-type', 'accept', 'accept-language', 'user-agent', 'x-internal-service' # Required for internal service-to-service ML/alert triggers ] def __init__(self): self._initialized = False def initialize(self): """Initialize header manager""" if not self._initialized: logger.info("HeaderManager initialized") self._initialized = True def sanitize_incoming_headers(self, request: Request) -> None: """ Remove sensitive headers from incoming request to prevent spoofing """ if not hasattr(request.headers, '_list'): return # Filter out headers that start with sanitized prefixes sanitized_headers = [ (k, v) for k, v in request.headers.raw if not any(k.decode().lower().startswith(prefix.lower()) for prefix in self.SANITIZED_HEADERS) ] request.headers.__dict__["_list"] = sanitized_headers logger.debug("Sanitized incoming headers") def inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None) -> Dict[str, str]: """ Inject standardized context headers into request Returns dict of injected headers for reference """ injected_headers = {} # Ensure headers list exists if not hasattr(request.headers, '_list'): request.headers.__dict__["_list"] = [] # Store headers in request.state for cross-middleware access request.state.injected_headers = {} # User context headers if user_context.get('user_id'): header_name = self.STANDARD_HEADERS['user_id'] header_value = str(user_context['user_id']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value if user_context.get('email'): header_name = self.STANDARD_HEADERS['user_email'] header_value = str(user_context['email']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value if user_context.get('role'): header_name = self.STANDARD_HEADERS['user_role'] header_value = str(user_context['role']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # User type (service vs regular user) if user_context.get('type'): header_name = self.STANDARD_HEADERS['user_type'] header_value = str(user_context['type']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Service name for service tokens if user_context.get('service'): header_name = self.STANDARD_HEADERS['service_name'] header_value = str(user_context['service']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Tenant context if tenant_id: header_name = self.STANDARD_HEADERS['tenant_id'] header_value = str(tenant_id) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Subscription context if user_context.get('subscription_tier'): header_name = self.STANDARD_HEADERS['subscription_tier'] header_value = str(user_context['subscription_tier']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value if user_context.get('subscription_status'): header_name = self.STANDARD_HEADERS['subscription_status'] header_value = str(user_context['subscription_status']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Demo session context is_demo = user_context.get('is_demo', False) if is_demo: header_name = self.STANDARD_HEADERS['is_demo'] header_value = "true" self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value if user_context.get('demo_session_id'): header_name = self.STANDARD_HEADERS['demo_session_id'] header_value = str(user_context['demo_session_id']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value if user_context.get('demo_account_type'): header_name = self.STANDARD_HEADERS['demo_account_type'] header_value = str(user_context['demo_account_type']) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Hierarchical access context if tenant_id: tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct') can_view_children = getattr(request.state, 'can_view_children', False) header_name = self.STANDARD_HEADERS['tenant_access_type'] header_value = str(tenant_access_type) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value header_name = self.STANDARD_HEADERS['can_view_children'] header_value = str(can_view_children).lower() self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Parent tenant ID if hierarchical access parent_tenant_id = getattr(request.state, 'parent_tenant_id', None) if parent_tenant_id: header_name = self.STANDARD_HEADERS['parent_tenant_id'] header_value = str(parent_tenant_id) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Gateway identification header_name = self.STANDARD_HEADERS['forwarded_by'] header_value = "bakery-gateway" self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value # Request ID if available request_id = getattr(request.state, 'request_id', None) if request_id: header_name = self.STANDARD_HEADERS['request_id'] header_value = str(request_id) self._add_header(request, header_name, header_value) injected_headers[header_name] = header_value request.state.injected_headers[header_name] = header_value logger.info("🔧 Injected context headers", user_id=user_context.get('user_id'), user_type=user_context.get('type', ''), service_name=user_context.get('service', ''), role=user_context.get('role', ''), tenant_id=tenant_id, is_demo=is_demo, demo_session_id=user_context.get('demo_session_id', ''), path=request.url.path) return injected_headers def _add_header(self, request: Request, header_name: str, header_value: str) -> None: """ Safely add header to request """ try: request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode())) except Exception as e: logger.warning(f"Failed to add header {header_name}: {e}") def get_forwardable_headers(self, request: Request) -> Dict[str, str]: """ Get headers that should be forwarded to downstream services Includes both original request headers and injected context headers """ forwardable_headers = {} # Add forwardable original headers for header_name in self.FORWARDABLE_HEADERS: header_value = request.headers.get(header_name) if header_value: forwardable_headers[header_name] = header_value # Add injected context headers from request.state if hasattr(request.state, 'injected_headers'): for header_name, header_value in request.state.injected_headers.items(): forwardable_headers[header_name] = header_value # Add authorization header if present auth_header = request.headers.get('authorization') if auth_header: forwardable_headers['authorization'] = auth_header return forwardable_headers def get_all_headers_for_proxy(self, request: Request, additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: """ Get complete set of headers for proxying to downstream services """ headers = self.get_forwardable_headers(request) # Add any additional headers if additional_headers: headers.update(additional_headers) # Remove host header as it will be set by httpx headers.pop('host', None) return headers def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool: """ Validate that required headers are present """ missing_headers = [] for header_name in required_headers: # Check in injected headers first if hasattr(request.state, 'injected_headers'): if header_name in request.state.injected_headers: continue # Check in request headers if request.headers.get(header_name): continue missing_headers.append(header_name) if missing_headers: logger.warning(f"Missing required headers: {missing_headers}") return False return True def get_header_value(self, request: Request, header_name: str, default: Optional[str] = None) -> Optional[str]: """ Get header value from either injected headers or request headers """ # Check injected headers first if hasattr(request.state, 'injected_headers'): if header_name in request.state.injected_headers: return request.state.injected_headers[header_name] # Check request headers return request.headers.get(header_name, default) def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None: """ Allow middleware to add headers to the unified header system This ensures all headers are available for proxying """ # Ensure injected_headers exists if not hasattr(request.state, 'injected_headers'): request.state.injected_headers = {} # Add header to injected_headers request.state.injected_headers[header_name] = header_value # Also add to actual request headers for compatibility try: request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode())) except Exception as e: logger.warning(f"Failed to add header {header_name} to request headers: {e}") logger.debug(f"Middleware added header: {header_name} = {header_value}") # Global instance for easy access header_manager = HeaderManager()