Add improvements 2

This commit is contained in:
Urtzi Alfaro
2026-01-12 22:15:11 +01:00
parent 230bbe6a19
commit b931a5c45e
40 changed files with 1820 additions and 887 deletions

View File

@@ -352,6 +352,25 @@ headers = {
- **Caching**: Gateway caches validated service tokens for 5 minutes
- **No Additional HTTP Calls**: Service auth happens locally at gateway
### Unified Header Management System
The gateway uses a **centralized HeaderManager** for consistent header handling across all middleware and proxy layers.
**Key Features:**
- Standardized header names and conventions
- Automatic header sanitization to prevent spoofing
- Unified header injection and forwarding
- Cross-middleware header access via `request.state.injected_headers`
- Consistent logging and error handling
**Standard Headers:**
- `x-user-id`, `x-user-email`, `x-user-role`, `x-user-type`
- `x-service-name`, `x-tenant-id`
- `x-subscription-tier`, `x-subscription-status`
- `x-is-demo`, `x-demo-session-id`, `x-demo-account-type`
- `x-tenant-access-type`, `x-can-view-children`, `x-parent-tenant-id`
- `x-forwarded-by`, `x-request-id`
### Context Header Injection
When a service token is validated, the gateway injects these headers for downstream services:

View File

@@ -0,0 +1,345 @@
"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""
import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List
logger = structlog.get_logger()
class HeaderManager:
"""
Centralized header management for consistent header handling across gateway
"""
# Standard header names (lowercase for consistency)
STANDARD_HEADERS = {
'user_id': 'x-user-id',
'user_email': 'x-user-email',
'user_role': 'x-user-role',
'user_type': 'x-user-type',
'service_name': 'x-service-name',
'tenant_id': 'x-tenant-id',
'subscription_tier': 'x-subscription-tier',
'subscription_status': 'x-subscription-status',
'is_demo': 'x-is-demo',
'demo_session_id': 'x-demo-session-id',
'demo_account_type': 'x-demo-account-type',
'tenant_access_type': 'x-tenant-access-type',
'can_view_children': 'x-can-view-children',
'parent_tenant_id': 'x-parent-tenant-id',
'forwarded_by': 'x-forwarded-by',
'request_id': 'x-request-id'
}
# Headers that should be sanitized/removed from incoming requests
SANITIZED_HEADERS = [
'x-subscription-',
'x-user-',
'x-tenant-',
'x-demo-',
'x-forwarded-by'
]
# Headers that should be forwarded to downstream services
FORWARDABLE_HEADERS = [
'authorization',
'content-type',
'accept',
'accept-language',
'user-agent',
'x-internal-service' # Required for internal service-to-service ML/alert triggers
]
def __init__(self):
self._initialized = False
def initialize(self):
"""Initialize header manager"""
if not self._initialized:
logger.info("HeaderManager initialized")
self._initialized = True
def sanitize_incoming_headers(self, request: Request) -> None:
"""
Remove sensitive headers from incoming request to prevent spoofing
"""
if not hasattr(request.headers, '_list'):
return
# Filter out headers that start with sanitized prefixes
sanitized_headers = [
(k, v) for k, v in request.headers.raw
if not any(k.decode().lower().startswith(prefix.lower())
for prefix in self.SANITIZED_HEADERS)
]
request.headers.__dict__["_list"] = sanitized_headers
logger.debug("Sanitized incoming headers")
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
tenant_id: Optional[str] = None) -> Dict[str, str]:
"""
Inject standardized context headers into request
Returns dict of injected headers for reference
"""
injected_headers = {}
# Ensure headers list exists
if not hasattr(request.headers, '_list'):
request.headers.__dict__["_list"] = []
# Store headers in request.state for cross-middleware access
request.state.injected_headers = {}
# User context headers
if user_context.get('user_id'):
header_name = self.STANDARD_HEADERS['user_id']
header_value = str(user_context['user_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('email'):
header_name = self.STANDARD_HEADERS['user_email']
header_value = str(user_context['email'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('role'):
header_name = self.STANDARD_HEADERS['user_role']
header_value = str(user_context['role'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# User type (service vs regular user)
if user_context.get('type'):
header_name = self.STANDARD_HEADERS['user_type']
header_value = str(user_context['type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Service name for service tokens
if user_context.get('service'):
header_name = self.STANDARD_HEADERS['service_name']
header_value = str(user_context['service'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Tenant context
if tenant_id:
header_name = self.STANDARD_HEADERS['tenant_id']
header_value = str(tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Subscription context
if user_context.get('subscription_tier'):
header_name = self.STANDARD_HEADERS['subscription_tier']
header_value = str(user_context['subscription_tier'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('subscription_status'):
header_name = self.STANDARD_HEADERS['subscription_status']
header_value = str(user_context['subscription_status'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Demo session context
is_demo = user_context.get('is_demo', False)
if is_demo:
header_name = self.STANDARD_HEADERS['is_demo']
header_value = "true"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_session_id'):
header_name = self.STANDARD_HEADERS['demo_session_id']
header_value = str(user_context['demo_session_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_account_type'):
header_name = self.STANDARD_HEADERS['demo_account_type']
header_value = str(user_context['demo_account_type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Hierarchical access context
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
can_view_children = getattr(request.state, 'can_view_children', False)
header_name = self.STANDARD_HEADERS['tenant_access_type']
header_value = str(tenant_access_type)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
header_name = self.STANDARD_HEADERS['can_view_children']
header_value = str(can_view_children).lower()
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Parent tenant ID if hierarchical access
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
if parent_tenant_id:
header_name = self.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Gateway identification
header_name = self.STANDARD_HEADERS['forwarded_by']
header_value = "bakery-gateway"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Request ID if available
request_id = getattr(request.state, 'request_id', None)
if request_id:
header_name = self.STANDARD_HEADERS['request_id']
header_value = str(request_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
logger.info("🔧 Injected context headers",
user_id=user_context.get('user_id'),
user_type=user_context.get('type', ''),
service_name=user_context.get('service', ''),
role=user_context.get('role', ''),
tenant_id=tenant_id,
is_demo=is_demo,
demo_session_id=user_context.get('demo_session_id', ''),
path=request.url.path)
return injected_headers
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
"""
Safely add header to request
"""
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name}: {e}")
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
"""
Get headers that should be forwarded to downstream services
Includes both original request headers and injected context headers
"""
forwardable_headers = {}
# Add forwardable original headers
for header_name in self.FORWARDABLE_HEADERS:
header_value = request.headers.get(header_name)
if header_value:
forwardable_headers[header_name] = header_value
# Add injected context headers from request.state
if hasattr(request.state, 'injected_headers'):
for header_name, header_value in request.state.injected_headers.items():
forwardable_headers[header_name] = header_value
# Add authorization header if present
auth_header = request.headers.get('authorization')
if auth_header:
forwardable_headers['authorization'] = auth_header
return forwardable_headers
def get_all_headers_for_proxy(self, request: Request,
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
Get complete set of headers for proxying to downstream services
"""
headers = self.get_forwardable_headers(request)
# Add any additional headers
if additional_headers:
headers.update(additional_headers)
# Remove host header as it will be set by httpx
headers.pop('host', None)
return headers
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
"""
Validate that required headers are present
"""
missing_headers = []
for header_name in required_headers:
# Check in injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
continue
# Check in request headers
if request.headers.get(header_name):
continue
missing_headers.append(header_name)
if missing_headers:
logger.warning(f"Missing required headers: {missing_headers}")
return False
return True
def get_header_value(self, request: Request, header_name: str,
default: Optional[str] = None) -> Optional[str]:
"""
Get header value from either injected headers or request headers
"""
# Check injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
return request.state.injected_headers[header_name]
# Check request headers
return request.headers.get(header_name, default)
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
"""
Allow middleware to add headers to the unified header system
This ensures all headers are available for proxying
"""
# Ensure injected_headers exists
if not hasattr(request.state, 'injected_headers'):
request.state.injected_headers = {}
# Add header to injected_headers
request.state.injected_headers[header_name] = header_value
# Also add to actual request headers for compatibility
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
logger.debug(f"Middleware added header: {header_name} = {header_value}")
# Global instance for easy access
header_manager = HeaderManager()

View File

@@ -16,6 +16,7 @@ from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from shared.service_base import StandardFastAPIService
from app.core.config import settings
from app.core.header_manager import header_manager
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
@@ -50,6 +51,10 @@ class GatewayService(StandardFastAPIService):
"""Custom startup logic for Gateway"""
global redis_client
# Initialize HeaderManager
header_manager.initialize()
logger.info("HeaderManager initialized")
# Initialize Redis
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)

View File

@@ -14,6 +14,7 @@ import httpx
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
if request.method == "OPTIONS":
return await call_next(request)
# SECURITY: Remove any incoming x-subscription-* headers
# These will be re-injected from verified JWT only
sanitized_headers = [
(k, v) for k, v in request.headers.raw
if not k.decode().lower().startswith('x-subscription-')
and not k.decode().lower().startswith('x-user-')
and not k.decode().lower().startswith('x-tenant-')
]
request.headers.__dict__["_list"] = sanitized_headers
# SECURITY: Remove any incoming sensitive headers using HeaderManager
header_manager.sanitize_incoming_headers(request)
# Skip authentication for public routes
if self._is_public_route(request.url.path):
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
"""
Inject user and tenant context headers for downstream services
ENHANCED: Added logging to verify header injection
Inject user and tenant context headers for downstream services using unified HeaderManager
"""
# Enhanced logging for debugging
logger.info(
"🔧 Injecting context headers",
user_id=user_context.get("user_id"),
user_type=user_context.get("type", ""),
service_name=user_context.get("service", ""),
role=user_context.get("role", ""),
tenant_id=tenant_id,
is_demo=user_context.get("is_demo", False),
demo_session_id=user_context.get("demo_session_id", ""),
path=request.url.path
)
# Add user context headers
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
# Use unified HeaderManager for consistent header injection
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
# Store headers in request.state for cross-middleware access
request.state.injected_headers = {
"x-user-id": user_context["user_id"],
"x-user-email": user_context["email"],
"x-user-role": user_context.get("role", "user")
}
request.headers.__dict__["_list"].append((
b"x-user-id", user_context["user_id"].encode()
))
request.headers.__dict__["_list"].append((
b"x-user-email", user_context["email"].encode()
))
user_role = user_context.get("role", "user")
request.headers.__dict__["_list"].append((
b"x-user-role", user_role.encode()
))
user_type = user_context.get("type", "")
if user_type:
request.headers.__dict__["_list"].append((
b"x-user-type", user_type.encode()
))
service_name = user_context.get("service", "")
if service_name:
request.headers.__dict__["_list"].append((
b"x-service-name", service_name.encode()
))
# Add tenant context if available
if tenant_id:
request.headers.__dict__["_list"].append((
b"x-tenant-id", tenant_id.encode()
))
# Add subscription tier if available
subscription_tier = user_context.get("subscription_tier", "")
if subscription_tier:
request.headers.__dict__["_list"].append((
b"x-subscription-tier", subscription_tier.encode()
))
# Add is_demo flag for demo sessions
is_demo = user_context.get("is_demo", False)
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
if is_demo:
logger.info(f"🎭 Adding demo session headers",
demo_session_id=user_context.get("demo_session_id", ""),
demo_account_type=user_context.get("demo_account_type", ""),
path=request.url.path)
request.headers.__dict__["_list"].append((
b"x-is-demo", b"true"
))
else:
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
# Add demo session context headers for backend services
demo_session_id = user_context.get("demo_session_id", "")
if demo_session_id:
request.headers.__dict__["_list"].append((
b"x-demo-session-id", demo_session_id.encode()
))
demo_account_type = user_context.get("demo_account_type", "")
if demo_account_type:
request.headers.__dict__["_list"].append((
b"x-demo-account-type", demo_account_type.encode()
))
# Add hierarchical access headers if tenant context exists
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
can_view_children = getattr(request.state, 'can_view_children', False)
request.headers.__dict__["_list"].append((
b"x-tenant-access-type", tenant_access_type.encode()
))
request.headers.__dict__["_list"].append((
b"x-can-view-children", str(can_view_children).encode()
))
# If this is hierarchical access, include parent tenant ID
# Get parent tenant ID from the auth service if available
try:
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
hierarchy_data = response.json()
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
if parent_tenant_id:
request.headers.__dict__["_list"].append((
b"x-parent-tenant-id", parent_tenant_id.encode()
))
# Add parent tenant ID using HeaderManager for consistency
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
header_manager.add_header_for_middleware(request, header_name, header_value)
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
except Exception as e:
logger.warning(f"Failed to get parent tenant ID: {e}")
pass
# Add gateway identification
request.headers.__dict__["_list"].append((
b"x-forwarded-by", b"bakery-gateway"
))
return injected_headers
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
"""

View File

@@ -45,8 +45,17 @@ class APIRateLimitMiddleware(BaseHTTPMiddleware):
return await call_next(request)
try:
# Get subscription tier
subscription_tier = await self._get_subscription_tier(tenant_id, request)
# Get subscription tier from headers (added by AuthMiddleware)
subscription_tier = request.headers.get("x-subscription-tier")
if not subscription_tier:
# Fallback: get from request state if headers not available
subscription_tier = getattr(request.state, "subscription_tier", None)
if not subscription_tier:
# Final fallback: get from tenant service (should rarely happen)
subscription_tier = await self._get_subscription_tier(tenant_id, request)
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
# Get quota limit for tier
quota_limit = self._get_quota_limit(subscription_tier)

View File

@@ -9,6 +9,8 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.core.header_manager import header_manager
logger = structlog.get_logger()
@@ -40,11 +42,9 @@ class RequestIDMiddleware(BaseHTTPMiddleware):
# Bind request ID to structured logger context
logger_ctx = logger.bind(request_id=request_id)
# Inject request ID header for downstream services
# This is done by modifying the headers that will be forwarded
request.headers.__dict__["_list"].append((
b"x-request-id", request_id.encode()
))
# Inject request ID header for downstream services using HeaderManager
# Note: This runs early in middleware chain, so we use add_header_for_middleware
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
# Log request start
logger_ctx.info(

View File

@@ -15,6 +15,7 @@ import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.core.header_manager import header_manager
from app.utils.subscription_error_responses import create_upgrade_required_response
logger = structlog.get_logger()
@@ -178,7 +179,10 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
r'/api/v1/subscriptions/.*', # Subscription management itself
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
r'/docs.*',
r'/openapi\.json'
r'/openapi\.json',
# Training monitoring endpoints (WebSocket and status checks)
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
]
# Skip OPTIONS requests (CORS preflight)
@@ -275,21 +279,11 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
'current_tier': current_tier
}
# Use the same authentication pattern as gateway routes for fallback
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header handling
headers = header_manager.get_all_headers_for_proxy(request)
# Extract user_id for logging (fallback path)
user_id = 'unknown'
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
user_id = str(user.get('user_id', 'unknown'))
headers["x-user-id"] = user_id
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
timeout_config = httpx.Timeout(

View File

@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse
from typing import Dict, Any
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
@@ -136,107 +137,32 @@ class AuthProxy:
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
# Convert headers to dict - get ALL headers including those added by middleware
# Middleware adds headers to _list, so we need to read from there
logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
if hasattr(headers, '_list'):
logger.debug(f"DEBUG: Entering _list branch")
logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
# Debug: Show first few headers in the list
debug_headers = []
for i, (k, v) in enumerate(all_headers_list):
if i < 5: # Show first 5 headers for debugging
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
else:
# Fallback: convert headers to dict manually
all_headers = {}
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
debug_headers.append(f"{key}: {value}")
logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
# Convert to dict for easier processing
all_headers = {}
for k, v in all_headers_list:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Debug: Show if x-user-id and x-is-demo are in the dict
logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...") # Show first 10 keys
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
logger.debug(f"DEBUG: Entering raw branch")
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# Fallback to raw headers if _list not available
all_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
}
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
else:
# Handle case where headers is already a dict
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
filtered_headers['X-Gateway-Version'] = '1.0.0'
return filtered_headers
all_headers[key] = value
elif hasattr(headers, 'raw'):
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Headers is already a dict
all_headers = dict(headers)
# Debug logging
logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""

View File

@@ -8,6 +8,7 @@ import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
@@ -29,12 +30,8 @@ async def proxy_demo_service(path: str, request: Request):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Forward headers (excluding host)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
try:
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -5,6 +5,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@@ -26,12 +27,8 @@ async def proxy_geocoding(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Forward headers (excluding host)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the proxied request
async with httpx.AsyncClient(timeout=30.0) as client:

View File

@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@@ -44,12 +45,8 @@ async def proxy_poi_context(request: Request, path: str):
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Copy headers (exclude host and content-length as they'll be set by httpx)
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length"]
}
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the request to the external service
async with httpx.AsyncClient(timeout=60.0) as client:

View File

@@ -8,6 +8,7 @@ import httpx
import logging
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -45,9 +46,8 @@ async def _proxy_to_pos_service(request: Request, target_path: str):
try:
url = f"{settings.POS_SERVICE_URL}{target_path}"
# Forward headers
headers = dict(request.headers)
headers.pop("host", None)
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add query parameters
params = dict(request.query_params)

View File

@@ -9,6 +9,7 @@ import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -98,29 +99,13 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
try:
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Debug logging
user_context = getattr(request.state, 'user', None)
if user_context:
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
logger.warning(f"No user context available when forwarding subscription request to {url}")

View File

@@ -10,6 +10,7 @@ import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
@@ -715,36 +716,18 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
try:
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Add tenant ID header if provided
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add tenant ID header if provided (override if needed)
if tenant_id:
headers["X-Tenant-ID"] = tenant_id
# Add user context headers if available
if hasattr(request.state, 'user') and request.state.user:
user = request.state.user
headers["x-user-id"] = str(user.get('user_id', ''))
headers["x-user-email"] = str(user.get('email', ''))
headers["x-user-role"] = str(user.get('role', 'user'))
headers["x-user-full-name"] = str(user.get('full_name', ''))
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
# Add subscription context headers
if user.get('subscription_tier'):
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
if user.get('subscription_status'):
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
# Debug logging
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
headers["x-tenant-id"] = tenant_id
# Debug logging
user_context = getattr(request.state, 'user', None)
if user_context:
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
# Debug logging when no user context available
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
# Get request body if present
@@ -782,9 +765,10 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
# Remove content-type from headers - httpx will set it with new boundary
headers.pop("content-type", None)
headers.pop("content-length", None)
# For multipart requests, we need to get fresh headers since httpx will set content-type
# Get all headers again to ensure we have the complete set
headers = header_manager.get_all_headers_for_proxy(request)
# httpx will automatically set content-type for multipart, so we don't need to remove it
else:
# For other content types, use body as before
body = await request.body()

View File

@@ -13,6 +13,7 @@ from typing import Dict, Any
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
@@ -136,64 +137,28 @@ class UserProxy:
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
# Remove hop-by-hop headers
hop_by_hop_headers = {
'connection', 'keep-alive', 'proxy-authenticate',
'proxy-authorization', 'te', 'trailers', 'upgrade'
}
# Convert headers to dict if it's a Headers object
# This ensures we get ALL headers including those added by middleware
if hasattr(headers, '_list'):
# Get headers from the _list where middleware adds them
all_headers_list = headers.__dict__.get('_list', [])
# Convert to dict for easier processing
all_headers = {}
for k, v in all_headers_list:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
# Check if headers are missing and try to get them from request.state
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
# Add missing headers from request.state
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
# Add is_demo flag if this is a demo session
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
all_headers['x-is-demo'] = 'true'
# Filter out hop-by-hop headers
filtered_headers = {
k: v for k, v in all_headers.items()
if k.lower() not in hop_by_hop_headers
}
elif hasattr(headers, 'raw'):
# FastAPI/Starlette Headers object - use raw to get all headers
filtered_headers = {
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
for k, v in headers.raw
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
}
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
else:
# Already a dict
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in hop_by_hop_headers
}
# Add gateway identifier
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
filtered_headers['X-Gateway-Version'] = '1.0.0'
return filtered_headers
# Fallback: convert headers to dict manually
all_headers = {}
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
elif hasattr(headers, 'raw'):
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Headers is already a dict
all_headers = dict(headers)
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""