New enterprise feature
This commit is contained in:
@@ -108,11 +108,19 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
user_context = request.state.user
|
||||
tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
|
||||
|
||||
# Inject subscription tier for demo sessions - always enterprise tier for full feature access
|
||||
user_context["subscription_tier"] = "enterprise"
|
||||
logger.debug(f"Demo session subscription tier set to enterprise", tenant_id=tenant_id)
|
||||
# For demo sessions, get the actual subscription tier from the tenant service
|
||||
# instead of always defaulting to enterprise
|
||||
if not user_context.get("subscription_tier"):
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
else:
|
||||
# Fallback to enterprise for demo if no tier is found
|
||||
user_context["subscription_tier"] = "enterprise"
|
||||
|
||||
self._inject_context_headers(request, user_context, tenant_id)
|
||||
logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)
|
||||
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
# ✅ STEP 1: Extract and validate JWT token
|
||||
@@ -159,14 +167,24 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
|
||||
# Check hierarchical access to determine access type and permissions
|
||||
hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
# Set tenant context in request state
|
||||
request.state.tenant_id = tenant_id
|
||||
request.state.tenant_verified = True
|
||||
request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
|
||||
request.state.can_view_children = hierarchical_access.get("can_view_children", False)
|
||||
|
||||
logger.debug(f"Tenant access verified",
|
||||
user_id=user_context["user_id"],
|
||||
tenant_id=tenant_id,
|
||||
subscription_tier=subscription_tier,
|
||||
access_type=hierarchical_access.get("access_type"),
|
||||
can_view_children=hierarchical_access.get("can_view_children"),
|
||||
path=request.url.path)
|
||||
|
||||
# ✅ STEP 5: Inject user context into request
|
||||
@@ -174,7 +192,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
request.state.authenticated = True
|
||||
|
||||
# ✅ STEP 6: Add context headers for downstream services
|
||||
self._inject_context_headers(request, user_context, tenant_id)
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
logger.debug(f"Authenticated request",
|
||||
user_email=user_context['email'],
|
||||
@@ -402,7 +420,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache user context: {e}")
|
||||
|
||||
def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services
|
||||
ENHANCED: Added logging to verify header injection
|
||||
@@ -456,6 +474,45 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
b"x-subscription-tier", subscription_tier.encode()
|
||||
))
|
||||
|
||||
# Add is_demo flag for demo sessions
|
||||
is_demo = user_context.get("is_demo", False)
|
||||
if is_demo:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-is-demo", b"true"
|
||||
))
|
||||
|
||||
# Add hierarchical access headers if tenant context exists
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-access-type", tenant_access_type.encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-can-view-children", str(can_view_children).encode()
|
||||
))
|
||||
|
||||
# If this is hierarchical access, include parent tenant ID
|
||||
# Get parent tenant ID from the auth service if available
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
|
||||
headers={"Authorization": request.headers.get("Authorization", "")}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
if parent_tenant_id:
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-parent-tenant-id", parent_tenant_id.encode()
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||
pass
|
||||
|
||||
# Add gateway identification
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-forwarded-by", b"bakery-gateway"
|
||||
|
||||
@@ -88,11 +88,6 @@ class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process the request and check subscription requirements"""
|
||||
|
||||
# Skip subscription check for demo sessions - they get enterprise tier
|
||||
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
|
||||
logger.debug("Skipping subscription check for demo session", path=request.url.path)
|
||||
return await call_next(request)
|
||||
|
||||
# Skip subscription check for certain routes
|
||||
if self._should_skip_subscription_check(request):
|
||||
return await call_next(request)
|
||||
|
||||
Reference in New Issue
Block a user