Add improvements 2
This commit is contained in:
@@ -14,6 +14,7 @@ import httpx
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||
|
||||
@@ -60,15 +61,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# SECURITY: Remove any incoming x-subscription-* headers
|
||||
# These will be re-injected from verified JWT only
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not k.decode().lower().startswith('x-subscription-')
|
||||
and not k.decode().lower().startswith('x-user-')
|
||||
and not k.decode().lower().startswith('x-tenant-')
|
||||
]
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
# SECURITY: Remove any incoming sensitive headers using HeaderManager
|
||||
header_manager.sanitize_incoming_headers(request)
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
@@ -573,109 +567,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services
|
||||
ENHANCED: Added logging to verify header injection
|
||||
Inject user and tenant context headers for downstream services using unified HeaderManager
|
||||
"""
|
||||
# Enhanced logging for debugging
|
||||
logger.info(
|
||||
"🔧 Injecting context headers",
|
||||
user_id=user_context.get("user_id"),
|
||||
user_type=user_context.get("type", ""),
|
||||
service_name=user_context.get("service", ""),
|
||||
role=user_context.get("role", ""),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=user_context.get("is_demo", False),
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
# Add user context headers
|
||||
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
|
||||
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
|
||||
# Use unified HeaderManager for consistent header injection
|
||||
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {
|
||||
"x-user-id": user_context["user_id"],
|
||||
"x-user-email": user_context["email"],
|
||||
"x-user-role": user_context.get("role", "user")
|
||||
}
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-id", user_context["user_id"].encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-email", user_context["email"].encode()
|
||||
))
|
||||
|
||||
user_role = user_context.get("role", "user")
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-role", user_role.encode()
|
||||
))
|
||||
|
||||
user_type = user_context.get("type", "")
|
||||
if user_type:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-type", user_type.encode()
|
||||
))
|
||||
|
||||
service_name = user_context.get("service", "")
|
||||
if service_name:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-service-name", service_name.encode()
|
||||
))
|
||||
|
||||
# Add tenant context if available
|
||||
if tenant_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-id", tenant_id.encode()
|
||||
))
|
||||
|
||||
# Add subscription tier if available
|
||||
subscription_tier = user_context.get("subscription_tier", "")
|
||||
if subscription_tier:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-subscription-tier", subscription_tier.encode()
|
||||
))
|
||||
|
||||
# Add is_demo flag for demo sessions
|
||||
is_demo = user_context.get("is_demo", False)
|
||||
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
|
||||
if is_demo:
|
||||
logger.info(f"🎭 Adding demo session headers",
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
demo_account_type=user_context.get("demo_account_type", ""),
|
||||
path=request.url.path)
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-is-demo", b"true"
|
||||
))
|
||||
else:
|
||||
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
|
||||
|
||||
# Add demo session context headers for backend services
|
||||
demo_session_id = user_context.get("demo_session_id", "")
|
||||
if demo_session_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-demo-session-id", demo_session_id.encode()
|
||||
))
|
||||
|
||||
demo_account_type = user_context.get("demo_account_type", "")
|
||||
if demo_account_type:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-demo-account-type", demo_account_type.encode()
|
||||
))
|
||||
|
||||
# Add hierarchical access headers if tenant context exists
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-access-type", tenant_access_type.encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-can-view-children", str(can_view_children).encode()
|
||||
))
|
||||
|
||||
# If this is hierarchical access, include parent tenant ID
|
||||
# Get parent tenant ID from the auth service if available
|
||||
try:
|
||||
@@ -689,17 +587,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
if parent_tenant_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-parent-tenant-id", parent_tenant_id.encode()
|
||||
))
|
||||
# Add parent tenant ID using HeaderManager for consistency
|
||||
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
header_manager.add_header_for_middleware(request, header_name, header_value)
|
||||
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||
pass
|
||||
|
||||
# Add gateway identification
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-forwarded-by", b"bakery-gateway"
|
||||
))
|
||||
|
||||
return injected_headers
|
||||
|
||||
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user