Files
bakery-ia/gateway/app/middleware/auth.py

701 lines
30 KiB
Python
Raw Normal View History

2025-07-26 18:46:52 +02:00
# gateway/app/middleware/auth.py
2025-07-19 17:49:03 +02:00
"""
2025-07-26 18:46:52 +02:00
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
2025-07-26 20:04:24 +02:00
FIXED VERSION - Proper JWT verification and token structure handling
2025-07-19 17:49:03 +02:00
"""
import structlog
2026-01-10 21:45:37 +01:00
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
2025-07-17 19:54:04 +02:00
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
2025-07-19 17:49:03 +02:00
from typing import Optional, Dict, Any
2025-07-26 20:04:24 +02:00
import httpx
import json
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
2025-07-26 18:46:52 +02:00
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
2025-07-19 17:49:03 +02:00
logger = structlog.get_logger()

# Local JWT verifier. It must be built with the same secret and algorithm as
# the auth service; tokens it cannot verify fall back to slower verification
# paths in AuthMiddleware.
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

# Path prefixes that bypass authentication (compared with str.startswith,
# so every sub-path of an entry is public as well).
PUBLIC_ROUTES = [
    # Operational and API-documentation endpoints
    "/health",
    "/metrics",
    "/docs",
    "/redoc",
    "/openapi.json",
    # Authentication flows
    "/api/v1/auth/login",
    "/api/v1/auth/register",
    "/api/v1/auth/refresh",
    "/api/v1/auth/verify",
    # Other public API endpoints
    "/api/v1/nominatim/search",
    "/api/v1/plans",
    # Demo onboarding
    "/api/v1/demo/accounts",
    "/api/v1/demo/sessions",
]

# Routes reachable with a demo session header instead of a JWT.
# NOTE(review): not referenced anywhere in this file; confirm whether other
# modules import it before relying on it.
DEMO_ACCESSIBLE_ROUTES = [
    "/api/v1/tenants/",  # All tenant endpoints accessible in demo mode
]
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Enhanced Authentication Middleware with Tenant Access Control.

    Request flow (see ``dispatch``):

    1. OPTIONS (CORS preflight) requests pass straight through.
    2. Client-supplied headers that only this gateway may set are stripped,
       so callers cannot spoof user, tenant or subscription context.
    3. Public routes (``PUBLIC_ROUTES``, prefix match) pass through.
    4. Demo sessions are accepted when ``request.state`` already marks the
       request as a demo session with a user, or - for the SSE endpoint only -
       after validating the ``demo_session_id`` query parameter here.
    5. Everything else needs a valid JWT; tenant-scoped paths additionally
       require tenant access. Verified context is forwarded downstream as
       ``x-*`` request headers.
    """

    # Lower-case header-name prefixes that only the gateway may set. Matching
    # headers are removed from every incoming request; verified values are
    # re-added by _inject_context_headers.
    _GATEWAY_OWNED_HEADER_PREFIXES = (
        "x-subscription-",
        "x-user-",
        "x-tenant-",
        "x-service-name",
        "x-parent-tenant-id",
        "x-can-view-children",
    )

    def __init__(self, app, redis_client=None):
        super().__init__(app)
        # Optional Redis client: token-context cache, token-freshness lookups,
        # and (lazily) shared with tenant_access_manager.
        self.redis_client = redis_client

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Authenticate the request, enforce tenant access and forward context.

        Returns a 401/403/503 ``JSONResponse`` when the request is rejected.
        Otherwise returns the downstream response, with
        ``X-Token-Refresh-Suggested: true`` added when the JWT expires within
        five minutes.
        """
        # CORS preflight requests carry no credentials
        if request.method == "OPTIONS":
            return await call_next(request)

        # SECURITY: must run before any branch that forwards the request
        self._strip_gateway_owned_headers(request)

        if self._is_public_route(request.url.path):
            return await call_next(request)

        demo_session_header = request.headers.get("X-Demo-Session-Id")
        # EventSource (SSE) cannot send headers, so the demo id may come in the query string
        demo_session_query = request.query_params.get("demo_session_id")
        logger.debug(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")

        if (
            request.url.path == "/api/events"
            and demo_session_query
            and not hasattr(request.state, "is_demo_session")
        ):
            rejection = await self._validate_sse_demo_session(request, demo_session_query)
            if rejection is not None:
                return rejection

        if getattr(request.state, "is_demo_session", False) and getattr(request.state, "user", None):
            return await self._dispatch_demo_session(request, call_next)

        # STEP 1: Extract the JWT
        token = self._extract_token(request)
        if not token:
            logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"}
            )

        # STEP 2: Verify it and build the user context
        user_context = await self._verify_token(token, request)
        if not user_context:
            logger.warning(f"Invalid token for route: {request.url.path}")
            return JSONResponse(
                status_code=401,
                content={"detail": "User not authenticated"}
            )

        # STEPS 3-4: Tenant from the URL, verified for tenant-scoped routes
        tenant_id = extract_tenant_id_from_path(request.url.path)
        if tenant_id and is_tenant_scoped_path(request.url.path):
            rejection = await self._authorize_tenant_request(request, user_context, tenant_id)
            if rejection is not None:
                return rejection

        # STEP 5: Expose the user context to downstream handlers
        request.state.user = user_context
        request.state.authenticated = True

        # STEP 6: Context headers for downstream services
        await self._inject_context_headers(request, user_context, tenant_id)

        logger.debug("Authenticated request",
                     user_email=user_context['email'],
                     tenant_id=tenant_id,
                     path=request.url.path)

        response = await call_next(request)
        if getattr(request.state, "token_near_expiry", False):
            response.headers["X-Token-Refresh-Suggested"] = "true"
        return response

    def _strip_gateway_owned_headers(self, request: Request) -> None:
        """
        Remove client-supplied headers whose names start with a gateway-owned prefix.

        The filtered list is bound to BOTH the cached ``request.headers``
        object and ``request.scope["headers"]``. Starlette's
        BaseHTTPMiddleware forwards ``request.scope`` downstream. Rebinding
        only the cached Headers object would therefore leave the scope
        unsanitised, and headers appended later by ``_inject_context_headers``
        would never reach downstream apps.
        """
        sanitized_headers = [
            (name, value)
            for name, value in request.headers.raw
            if not name.decode("latin-1").lower().startswith(self._GATEWAY_OWNED_HEADER_PREFIXES)
        ]
        request.headers.__dict__["_list"] = sanitized_headers
        request.scope["headers"] = sanitized_headers

    async def _validate_sse_demo_session(self, request: Request, demo_session_id: str) -> Optional[Response]:
        """
        Validate a demo session id passed as a query parameter to the SSE endpoint.

        On success, sets ``request.state.is_demo_session``, ``request.state.user``
        and ``request.state.tenant_id``, and returns ``None``. Otherwise returns
        the response to send: 401 for an invalid session, 503 when the
        demo-session service cannot be reached or answers unexpectedly.
        """
        from urllib.parse import quote

        logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_id}")

        # The id is client-controlled and is placed in an internal URL sent
        # with the internal API key. Reject dot-only values (path traversal)
        # and percent-encode the rest so it stays a single path segment.
        if not demo_session_id.strip("."):
            logger.warning(f"Invalid demo session for SSE: {demo_session_id}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid demo session"}
            )

        # Overridable via settings; defaults are the previously hard-coded values.
        # NOTE(review): confirm these settings names with app.core.config.
        service_url = getattr(settings, "DEMO_SESSION_SERVICE_URL", "http://demo-session-service:8000")
        internal_api_key = getattr(settings, "INTERNAL_API_KEY", "dev-internal-key-change-in-production")

        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.get(
                    f"{service_url}/api/v1/demo/sessions/{quote(demo_session_id, safe='')}",
                    headers={"X-Internal-API-Key": internal_api_key}
                )
            if response.status_code != 200:
                logger.warning(f"Invalid demo session for SSE: {demo_session_id}")
                return JSONResponse(
                    status_code=401,
                    content={"detail": "Invalid demo session"}
                )
            session_data = response.json()
            virtual_tenant_id = session_data.get("virtual_tenant_id")
            request.state.is_demo_session = True
            request.state.user = {
                "user_id": f"demo-user-{demo_session_id}",
                "email": f"demo-{demo_session_id}@demo.local",
                "tenant_id": virtual_tenant_id,
                "demo_session_id": demo_session_id,
            }
            request.state.tenant_id = virtual_tenant_id
        except Exception as e:
            logger.error(f"Failed to validate demo session for SSE: {e}")
            return JSONResponse(
                status_code=503,
                content={"detail": "Demo session service unavailable"}
            )

        logger.info(f"✅ Demo session validated for SSE: {demo_session_id}")
        return None

    async def _dispatch_demo_session(self, request: Request, call_next) -> Response:
        """
        Forward a request whose demo session is already established on ``request.state``.

        The user context was set upstream (the demo middleware, or
        ``_validate_sse_demo_session``) and is used as-is. If it carries no
        subscription tier, the tier is fetched from the tenant service. It
        falls back to ``"enterprise"`` when there is no tenant id or the
        lookup yields no tier.
        """
        logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
        user_context = request.state.user
        tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)

        if not user_context.get("subscription_tier"):
            # No tenant means nothing to look up - skip the HTTP call
            subscription_tier = (
                await self._get_tenant_subscription_tier(tenant_id, request) if tenant_id else None
            )
            user_context["subscription_tier"] = subscription_tier or "enterprise"
            logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)

        await self._inject_context_headers(request, user_context, tenant_id)
        return await call_next(request)

    async def _authorize_tenant_request(
        self,
        request: Request,
        user_context: Dict[str, Any],
        tenant_id: str,
    ) -> Optional[Response]:
        """
        Check the user's access to ``tenant_id`` and record tenant context.

        Returns a 403 response when access is denied. Otherwise it:

        - updates ``user_context["subscription_tier"]`` when a tier is found;
        - sets ``tenant_id``, ``tenant_verified``, ``tenant_access_type`` and
          ``can_view_children`` on ``request.state``;
        - returns ``None``.
        """
        # Share our Redis client with the access manager for its caching
        if self.redis_client and tenant_access_manager.redis_client is None:
            tenant_access_manager.redis_client = self.redis_client

        has_access = await tenant_access_manager.verify_basic_tenant_access(
            user_context["user_id"],
            tenant_id
        )
        if not has_access:
            logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
            return JSONResponse(
                status_code=403,
                content={"detail": f"Access denied to tenant {tenant_id}"}
            )

        # A JWT subscription claim describes the token's own tenant only.
        # For any other tenant in the URL, ask the tenant service instead.
        jwt_tenant_id = user_context.get("tenant_id")
        jwt_tier_applies = (
            user_context.get("subscription_from_jwt")
            and jwt_tenant_id is not None
            and str(jwt_tenant_id) == str(tenant_id)
        )
        if jwt_tier_applies:
            subscription_tier = user_context.get("subscription_tier")
            logger.debug("Using subscription tier from JWT", tier=subscription_tier)
        else:
            subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
        if subscription_tier:
            user_context["subscription_tier"] = subscription_tier

        # Hierarchical access decides access type and child-tenant visibility
        hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
            user_context["user_id"],
            tenant_id
        ) or {}

        request.state.tenant_id = tenant_id
        request.state.tenant_verified = True
        request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
        request.state.can_view_children = hierarchical_access.get("can_view_children", False)

        logger.debug("Tenant access verified",
                     user_id=user_context["user_id"],
                     tenant_id=tenant_id,
                     subscription_tier=subscription_tier,
                     access_type=hierarchical_access.get("access_type"),
                     can_view_children=hierarchical_access.get("can_view_children"),
                     path=request.url.path)
        return None

    def _is_public_route(self, path: str) -> bool:
        """
        Return True when ``path`` needs no authentication.

        This is a plain prefix match against ``PUBLIC_ROUTES``, so every
        sub-path of a public entry is public as well.
        NOTE(review): the prefix match also covers siblings such as
        "/healthz" for "/health"; confirm that is intended.
        """
        return any(path.startswith(route) for route in PUBLIC_ROUTES)

    def _extract_token(self, request: Request) -> Optional[str]:
        """
        Extract the JWT from the request, or return None if absent.

        For the SSE endpoint (``/api/events``) the token is read from the
        ``token`` query parameter, because the browser EventSource API cannot
        send custom headers. Every other route must send
        ``Authorization: Bearer <token>``. The scheme is matched
        case-insensitively and extra spaces are ignored.

        Security note: query-param tokens are logged. Use short expiry and
        filter logs.
        """
        if request.url.path == "/api/events":
            token = request.query_params.get("token")
            if not token:
                logger.warning("SSE request missing token in query param")
                return None
            logger.debug("Token extracted from query param for SSE endpoint")
            return token

        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        credentials = credentials.strip()
        if scheme.lower() != "bearer" or not credentials:
            return None
        # Only the first whitespace-separated value is the token
        return credentials.split()[0]

    async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
        """
        Verify a JWT and return the user context, or None if it is rejected.

        Strategies, in order:

        1. Local validation with ``jwt_handler`` plus payload checks. A
           locally valid token issued before the tenant's last subscription
           change is rejected outright, with no fallback.
        2. Redis cache of contexts previously confirmed by the auth service.
        3. The auth service's verify endpoint; successful results are cached.

        When ``request`` is given, ``request.state.token_near_expiry`` is set
        for locally validated tokens that expire within five minutes.
        """
        import time

        # Strategy 1: local JWT validation (fast path)
        try:
            payload = jwt_handler.verify_token(token)
            if payload and self._validate_token_payload(payload):
                logger.debug("Token validated locally")

                if payload.get("tenant_id") and request:
                    # _verify_token_freshness fails open on Redis errors, so
                    # False really means "stale". Return None here: raising
                    # would be swallowed by the except below, and falling
                    # through would let the fallbacks accept the token.
                    is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
                    if not is_fresh:
                        logger.warning("Stale token detected - subscription changed since token was issued",
                                       user_id=payload.get("user_id"),
                                       tenant_id=payload.get("tenant_id"))
                        return None

                if request and payload.get("exp", 0) - time.time() < 300:  # 5 minutes
                    request.state.token_near_expiry = True

                return self._jwt_payload_to_user_context(payload)
        except Exception as e:
            logger.debug(f"Local token validation failed: {e}")

        # Strategy 2: contexts recently confirmed by the auth service
        if self.redis_client:
            try:
                cached_user = await self._get_cached_user(token)
                if cached_user:
                    logger.debug("Token found in cache")
                    return cached_user
            except Exception as e:
                logger.warning(f"Cache lookup failed: {e}")

        # Strategy 3: the auth service (authoritative)
        try:
            user_context = await self._verify_with_auth_service(token)
            if user_context:
                if self.redis_client:
                    await self._cache_user(token, user_context)
                logger.debug("Token validated by auth service")
                return user_context
        except Exception as e:
            logger.error(f"Auth service validation failed: {e}")

        return None

    def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Check that a decoded JWT payload has the fields and type the gateway needs.

        Requires ``user_id``, ``email``, ``exp`` and ``type``, and ``type`` must
        be ``"access"`` or ``"service"``. A token expiring within five
        minutes is logged as a warning but still accepted.
        """
        import time

        required_fields = ["user_id", "email", "exp", "type"]
        missing_fields = [field for field in required_fields if field not in payload]
        if missing_fields:
            logger.warning(f"Token payload missing fields: {missing_fields}")
            return False

        token_type = payload.get("type")
        if token_type not in ("access", "service"):
            logger.warning(f"Invalid token type: {token_type}")
            return False

        time_until_expiry = payload.get("exp", 0) - time.time()
        if time_until_expiry < 300:  # 5 minutes
            logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")

        return True

    def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
        """
        Validate JWT payload integrity beyond signature verification.

        Requires ``user_id``, ``email``, ``exp``, ``iat`` and ``iss``. It also
        requires ``iss == "bakery-auth"`` and a ``type`` of ``"access"`` or
        ``"service"``. If a ``subscription`` claim is present, it must be a
        mapping whose tier (if any) is starter/professional/enterprise,
        compared case-insensitively.
        """
        required_fields = ["user_id", "email", "exp", "iat", "iss"]
        missing = [field for field in required_fields if field not in payload]
        if missing:
            logger.warning("JWT missing required fields", missing=missing)
            return False

        if payload.get("iss") != "bakery-auth":
            logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
            return False

        if payload.get("type") not in ["access", "service"]:
            logger.warning("JWT has invalid type", token_type=payload.get("type"))
            return False

        valid_tiers = ["starter", "professional", "enterprise"]
        subscription = payload.get("subscription")
        if subscription:
            if not isinstance(subscription, dict):
                logger.warning("JWT has invalid subscription claim",
                               claim_type=type(subscription).__name__)
                return False
            tier = str(subscription.get("tier") or "").lower()
            if tier and tier not in valid_tiers:
                logger.warning("JWT has invalid subscription tier", tier=tier)
                return False

        return True

    async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
        """
        Return False if the token was issued before the tenant's last subscription change.

        Compares the token's ``iat`` with the Redis key
        ``tenant:{tenant_id}:subscription_changed_at``. Fails open: returns
        True when there is no Redis client, the key is missing, or the lookup
        fails.
        """
        if not self.redis_client:
            return True  # Skip check if no Redis

        try:
            subscription_changed_at = await self.redis_client.get(
                f"tenant:{tenant_id}:subscription_changed_at"
            )
            if subscription_changed_at:
                changed_timestamp = float(subscription_changed_at)
                token_issued_at = payload.get("iat", 0)
                if token_issued_at < changed_timestamp:
                    logger.warning(
                        "Token issued before subscription change",
                        token_iat=token_issued_at,
                        subscription_changed=changed_timestamp,
                        tenant_id=tenant_id
                    )
                    return False  # Token is stale
        except Exception as e:
            logger.warning("Failed to check token freshness", error=str(e))

        return True

    def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert a verified JWT payload into the gateway's user-context dict.

        Raises HTTPException(401) when ``_validate_jwt_integrity`` fails;
        ``_verify_token`` catches it and falls back to the other strategies.
        Service tokens (``service`` claim) are mapped to an admin identity
        named ``<service>-service``.
        """
        if not self._validate_jwt_integrity(payload):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid JWT payload"
            )

        base_context = {
            "user_id": payload["user_id"],
            "email": payload["email"],
            "exp": payload["exp"],
            "valid": True,
            "role": payload.get("role", "user"),
        }

        if payload.get("tenant_id"):
            base_context["tenant_id"] = payload["tenant_id"]
            base_context["tenant_role"] = payload.get("tenant_role", "member")

        if payload.get("subscription"):
            sub = payload["subscription"]
            base_context["subscription_tier"] = sub.get("tier", "starter")
            base_context["subscription_status"] = sub.get("status", "active")
            # Lets dispatch skip the tenant-service call for the token's own tenant
            base_context["subscription_from_jwt"] = True

        if payload.get("tenant_access"):
            base_context["tenant_access"] = payload["tenant_access"]

        if payload.get("service"):
            service_name = payload["service"]
            base_context["service"] = service_name
            base_context["type"] = "service"
            base_context["role"] = "admin"
            base_context["user_id"] = f"{service_name}-service"
            base_context["email"] = f"{service_name}-service@internal"
            logger.debug(f"Service authentication: {payload['service']}")

        return base_context

    async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify the token with the auth service's ``/api/v1/auth/verify`` endpoint.

        Returns a user-context dict (``user_id``, ``email``, ``exp``,
        ``valid``) when the service answers 200 with ``valid`` and
        ``user_id``. Returns None on any other answer, a timeout (5 s) or an
        error.
        """
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                response = await client.post(
                    f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
                    headers={"Authorization": f"Bearer {token}"}
                )

            if response.status_code != 200:
                logger.warning(f"Auth service returned {response.status_code}: {response.text}")
                return None

            auth_response = response.json()
            if not (auth_response.get("valid") and auth_response.get("user_id")):
                logger.warning(f"Auth service returned invalid response: {auth_response}")
                return None

            return {
                "user_id": auth_response["user_id"],
                "email": auth_response["email"],
                "exp": auth_response.get("exp"),
                "valid": True
            }
        except httpx.TimeoutException:
            logger.error("Auth service timeout during token verification")
            return None
        except Exception as e:
            logger.error(f"Auth service error: {e}")
            return None

    @staticmethod
    def _token_cache_key(token: str) -> str:
        """
        Return the Redis key for a token's cached user context.

        Uses SHA-256 rather than the built-in ``hash()``. ``hash()`` is salted
        per process, so keys would differ between workers. Reducing it modulo
        a small number also let unrelated tokens collide and read another
        user's cached context.
        """
        import hashlib

        return f"auth:token:{hashlib.sha256(token.encode('utf-8')).hexdigest()}"

    async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Return the cached user context for ``token``, or None.

        Returns None when there is no Redis client, nothing is cached, or the
        cached value cannot be read or parsed.
        """
        if not self.redis_client:
            return None

        try:
            cached_data = await self.redis_client.get(self._token_cache_key(token))
            if cached_data:
                if isinstance(cached_data, bytes):
                    cached_data = cached_data.decode()
                return json.loads(cached_data)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse cached user data: {e}")
        except Exception as e:
            logger.warning(f"Cache lookup error: {e}")

        return None

    async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
        """
        Cache a verified user context (best-effort).

        The TTL is 5 minutes, capped at the token's remaining lifetime when
        ``user_context["exp"]`` is numeric. Tokens that have already expired
        are not cached. Failures are logged, not raised.
        """
        if not self.redis_client:
            return

        import time

        ttl = 300  # 5 minutes
        exp = user_context.get("exp")
        if isinstance(exp, (int, float)):
            ttl = min(ttl, int(exp - time.time()))
            if ttl <= 0:
                return

        try:
            await self.redis_client.setex(
                self._token_cache_key(token),
                ttl,
                json.dumps(user_context)
            )
        except Exception as e:
            logger.warning(f"Failed to cache user context: {e}")

    async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
        """
        Append user and tenant context headers for downstream services.

        Always adds ``x-user-id``, ``x-user-email``, ``x-user-role`` and
        ``x-forwarded-by``. Adds the following only when the corresponding
        value is set:

        - ``x-user-type``
        - ``x-service-name``
        - ``x-tenant-id``
        - ``x-subscription-tier``
        - ``x-is-demo``
        - ``x-demo-session-id``
        - ``x-demo-account-type``

        With a tenant it also adds ``x-tenant-access-type`` and
        ``x-can-view-children``, and ``x-parent-tenant-id`` when the tenant
        service's hierarchy endpoint reports a parent. The hierarchy lookup is
        best-effort with a 3 s timeout.
        """
        logger.debug(
            "Injecting context headers",
            user_id=user_context.get("user_id"),
            user_type=user_context.get("type", ""),
            service_name=user_context.get("service", ""),
            role=user_context.get("role", ""),
            tenant_id=tenant_id,
            path=request.url.path
        )

        header_list = request.headers.__dict__["_list"]

        def add_header(name: bytes, value: Any) -> None:
            header_list.append((name, str(value).encode()))

        # User identity
        add_header(b"x-user-id", user_context["user_id"])
        add_header(b"x-user-email", user_context.get("email", ""))
        add_header(b"x-user-role", user_context.get("role", "user"))

        user_type = user_context.get("type", "")
        if user_type:
            add_header(b"x-user-type", user_type)

        service_name = user_context.get("service", "")
        if service_name:
            add_header(b"x-service-name", service_name)

        if tenant_id:
            add_header(b"x-tenant-id", tenant_id)

        subscription_tier = user_context.get("subscription_tier", "")
        if subscription_tier:
            add_header(b"x-subscription-tier", subscription_tier)

        # Demo session context
        if user_context.get("is_demo", False):
            add_header(b"x-is-demo", "true")

        demo_session_id = user_context.get("demo_session_id", "")
        if demo_session_id:
            add_header(b"x-demo-session-id", demo_session_id)

        demo_account_type = user_context.get("demo_account_type", "")
        if demo_account_type:
            add_header(b"x-demo-account-type", demo_account_type)

        # Hierarchical tenant access
        if tenant_id:
            add_header(b"x-tenant-access-type", getattr(request.state, 'tenant_access_type', 'direct'))
            add_header(b"x-can-view-children", getattr(request.state, 'can_view_children', False))

            try:
                async with httpx.AsyncClient(timeout=3.0) as client:
                    response = await client.get(
                        f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
                        headers={"Authorization": request.headers.get("Authorization", "")}
                    )
                if response.status_code == 200:
                    parent_tenant_id = response.json().get("parent_tenant_id")
                    if parent_tenant_id:
                        add_header(b"x-parent-tenant-id", parent_tenant_id)
            except Exception as e:
                # Best-effort: the request proceeds without the parent header
                logger.warning(f"Failed to get parent tenant ID: {e}")

        add_header(b"x-forwarded-by", "bakery-gateway")

    async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
        """
        Get the tenant's subscription tier from the tenant service's cached tier endpoint.

        Args:
            tenant_id: Tenant ID
            request: Incoming request; its Authorization header is forwarded

        Returns:
            The ``tier`` field of a 200 response (``"starter"`` if the field is
            absent). Returns ``"starter"`` on a non-200 status or any error
            (3 s timeout).
        """
        try:
            async with httpx.AsyncClient(timeout=3.0) as client:
                response = await client.get(
                    f"{settings.TENANT_SERVICE_URL}/api/v1/subscriptions/{tenant_id}/tier",
                    headers={"Authorization": request.headers.get("Authorization", "")}
                )

            if response.status_code != 200:
                logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
                return "starter"  # Default to starter

            tier_data = response.json()
            subscription_tier = tier_data.get("tier", "starter")
            logger.debug("Subscription tier from cached endpoint",
                         tenant_id=tenant_id,
                         tier=subscription_tier,
                         cached=tier_data.get("cached", False))
            return subscription_tier
        except Exception as e:
            logger.error(f"Error getting tenant subscription tier: {e}")
            return "starter"  # Default to starter on error