Add improvements
This commit is contained in:
@@ -168,6 +168,228 @@ The architecture implements **defense-in-depth** with multiple validation layers
|
||||
- Gateway validates access to requested tenant
|
||||
- Supports hierarchical tenant access patterns
|
||||
|
||||
## JWT Service Token Authentication
|
||||
|
||||
### Overview
|
||||
The Gateway now supports **JWT service tokens** for secure service-to-service (S2S) communication. This replaces the deprecated internal API key system with a unified JWT-based authentication mechanism for both user and service requests.
|
||||
|
||||
### Service Token Support
|
||||
|
||||
**User Tokens** (frontend/API consumers):
|
||||
- `type: "access"` - Regular user authentication
|
||||
- Contains user ID, email, tenant membership, subscription data
|
||||
- Expires in 15-30 minutes
|
||||
- Validated and cached by gateway
|
||||
|
||||
**Service Tokens** (microservice communication):
|
||||
- `type: "service"` - Internal service authentication
|
||||
- Contains service name, admin role, optional tenant context
|
||||
- Expires in 1 hour
|
||||
- Automatically grants admin privileges to registered services
|
||||
|
||||
### Service Token Validation Flow
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ Calling Service│
|
||||
│ (e.g., demo) │
|
||||
└────────┬────────┘
|
||||
│
|
||||
│ 1. Create service token
|
||||
│ jwt_handler.create_service_token(
|
||||
│ service_name="demo-session",
|
||||
│ tenant_id=tenant_id
|
||||
│ )
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────┐
|
||||
│ HTTP Request to Gateway │
|
||||
│ -------------------------------- │
|
||||
│ POST /api/v1/tenant/clone │
|
||||
│ Headers: │
|
||||
│ Authorization: Bearer {service_token}│
|
||||
│ X-Service: demo-session-service │
|
||||
└────────┬────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Gateway │
|
||||
│ Auth Middleware│
|
||||
└────────┬────────┘
|
||||
│
|
||||
│ 2. Extract and verify JWT
|
||||
│ jwt_handler.verify_token(token)
|
||||
│
|
||||
│ 3. Identify service token
|
||||
│ if token.type == "service":
|
||||
│
|
||||
│ 4. Check internal service registry
|
||||
│ if is_internal_service(service_name):
|
||||
│ grant_admin_access()
|
||||
│ skip_tenant_membership_check()
|
||||
│
|
||||
│ 5. Inject service context headers
|
||||
│ X-User-ID: demo-session-service
|
||||
│ X-User-Role: admin
|
||||
│ X-Service-Name: demo-session
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Target Service │
|
||||
│ (e.g., tenant) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
### Internal Service Registry
|
||||
|
||||
The gateway uses a centralized registry of all 21 microservices:
|
||||
- **File**: `shared/config/base.py`
|
||||
- **Constant**: `INTERNAL_SERVICES` set
|
||||
- **Services**: gateway, auth, tenant, inventory, production, recipes, suppliers, orders, sales, procurement, pos, forecasting, training, ai-insights, orchestrator, notification, alert-processor, demo-session, external, distribution
|
||||
|
||||
**Automatic Privileges for Registered Services:**
|
||||
- Admin role granted automatically
|
||||
- Skip tenant membership validation
|
||||
- Access to all tenants within scope
|
||||
- Optimized database queries
|
||||
|
||||
### Service Token Payload
|
||||
|
||||
```json
|
||||
{
|
||||
"sub": "demo-session",
|
||||
"user_id": "demo-session-service",
|
||||
"email": "demo-session-service@internal",
|
||||
"service": "demo-session",
|
||||
"type": "service",
|
||||
"role": "admin",
|
||||
"tenant_id": "optional-tenant-uuid",
|
||||
"exp": 1735693199,
|
||||
"iat": 1735689599,
|
||||
"iss": "bakery-auth"
|
||||
}
|
||||
```
|
||||
|
||||
### Gateway Processing
|
||||
|
||||
#### Token Validation (`_validate_token_payload`)
|
||||
```python
|
||||
# Validates token type and required fields
|
||||
token_type = payload.get("type")
|
||||
if token_type not in ["access", "service"]:
|
||||
return False
|
||||
|
||||
# Service tokens with tenant context are valid
|
||||
if token_type == "service" and payload.get("tenant_id"):
|
||||
logger.debug("Service token with tenant context validated")
|
||||
```
|
||||
|
||||
#### User Context Extraction (`_jwt_payload_to_user_context`)
|
||||
```python
|
||||
# Detect service tokens
|
||||
if payload.get("service"):
|
||||
service_name = payload["service"]
|
||||
base_context = {
|
||||
"user_id": f"{service_name}-service",
|
||||
"email": f"{service_name}-service@internal",
|
||||
"service": service_name,
|
||||
"type": "service",
|
||||
"role": "admin", # Services get admin privileges
|
||||
"tenant_id": payload.get("tenant_id") # Optional tenant context
|
||||
}
|
||||
```
|
||||
|
||||
#### Tenant Access Control
|
||||
```python
|
||||
# Skip tenant access verification for service tokens
|
||||
if user_context.get("type") != "service":
|
||||
# Verify user has access to tenant
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"], tenant_id
|
||||
)
|
||||
else:
|
||||
# Services have automatic access
|
||||
logger.debug(f"Service token granted access to tenant {tenant_id}")
|
||||
```
|
||||
|
||||
### Migration from Internal API Keys
|
||||
|
||||
**Old System (Deprecated - Removed in 2026-01):**
|
||||
```python
|
||||
# REMOVED - No longer supported
|
||||
headers = {
|
||||
"X-Internal-API-Key": "dev-internal-key-change-in-production"
|
||||
}
|
||||
```
|
||||
|
||||
**New System (Current):**
|
||||
```python
|
||||
# Gateway creates service tokens for internal calls
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {service_token}"
|
||||
}
|
||||
```
|
||||
|
||||
### Security Benefits
|
||||
|
||||
1. **Token Expiration** - Service tokens expire (1 hour), preventing indefinite access
|
||||
2. **Signature Verification** - JWT signatures prevent token forgery and tampering
|
||||
3. **Tenant Scoping** - Service tokens can include tenant context for proper authorization
|
||||
4. **Unified Authentication** - Same JWT verification logic for user and service tokens
|
||||
5. **Audit Trail** - All service requests are authenticated and logged with service identity
|
||||
6. **No Shared Secrets** - Services don't share API keys; use shared JWT secret instead
|
||||
7. **Rotation Ready** - JWT secret can be rotated without code changes
|
||||
|
||||
### Performance Impact
|
||||
|
||||
- **Token Creation**: <1ms (in-memory JWT signing)
|
||||
- **Token Validation**: <1ms (in-memory JWT verification with shared secret)
|
||||
- **Caching**: Gateway caches validated service tokens for 5 minutes
|
||||
- **No Additional HTTP Calls**: Service auth happens locally at gateway
|
||||
|
||||
### Context Header Injection
|
||||
|
||||
When a service token is validated, the gateway injects these headers for downstream services:
|
||||
|
||||
```python
|
||||
X-User-ID: demo-session-service
|
||||
X-User-Email: demo-session-service@internal
|
||||
X-User-Role: admin
|
||||
X-User-Type: service
|
||||
X-Service-Name: demo-session
|
||||
X-Tenant-ID: {tenant_id} # If present in token
|
||||
```
|
||||
|
||||
### Gateway-to-Service Communication
|
||||
|
||||
The gateway itself creates service tokens when calling internal services:
|
||||
|
||||
#### Example: Demo Session Validation for SSE
|
||||
```python
|
||||
# gateway/app/middleware/auth.py
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://demo-session-service:8000/api/v1/demo/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
```
|
||||
|
||||
### Shared JWT Secret
|
||||
|
||||
All services (including gateway) use the same JWT secret key:
|
||||
- **File**: `shared/config/base.py`
|
||||
- **Variable**: `JWT_SECRET_KEY`
|
||||
- **Default**: `usMHw9kQCQoyrc7wPmMi3bClr0lTY9wvzZmcTbADvL0=`
|
||||
- **Environment Override**: `JWT_SECRET_KEY` environment variable
|
||||
- **Production**: Must be set to a secure random value
|
||||
|
||||
## API Endpoints (Key Routes)
|
||||
|
||||
### Authentication Routes
|
||||
|
||||
@@ -82,13 +82,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
# For SSE endpoint with demo_session_id in query params, validate it here
|
||||
if request.url.path == "/api/events" and demo_session_query and not hasattr(request.state, "is_demo_session"):
|
||||
logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_query}")
|
||||
# Validate demo session via demo-session service
|
||||
# Validate demo session via demo-session service using JWT service token
|
||||
import httpx
|
||||
try:
|
||||
# Create service token for gateway-to-demo-session communication
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://demo-session-service:8000/api/v1/demo/sessions/{demo_session_query}",
|
||||
headers={"X-Internal-API-Key": "dev-internal-key-change-in-production"}
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
session_data = response.json()
|
||||
@@ -161,22 +164,27 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
|
||||
if tenant_id and is_tenant_scoped_path(request.url.path):
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
# Skip tenant access verification for service tokens (services have admin access)
|
||||
if user_context.get("type") != "service":
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": f"Access denied to tenant {tenant_id}"}
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": f"Access denied to tenant {tenant_id}"}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Service token granted access to tenant {tenant_id}",
|
||||
service=user_context.get("service"))
|
||||
|
||||
# Get tenant subscription tier and inject into user context
|
||||
# NEW: Use JWT data if available, skip HTTP call
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
@@ -365,6 +373,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check setup failed", error=str(e))
|
||||
|
||||
# FIX: Validate service tokens with tenant context for tenant-scoped routes
|
||||
if token_type == "service" and payload.get("tenant_id"):
|
||||
# Service tokens with tenant context are valid for tenant-scoped operations
|
||||
logger.debug("Service token with tenant context validated",
|
||||
service=payload.get("service"), tenant_id=payload.get("tenant_id"))
|
||||
|
||||
return True
|
||||
|
||||
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
|
||||
@@ -469,7 +483,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
base_context["role"] = "admin"
|
||||
base_context["user_id"] = f"{service_name}-service"
|
||||
base_context["email"] = f"{service_name}-service@internal"
|
||||
logger.debug(f"Service authentication: {payload['service']}")
|
||||
|
||||
# FIX: Service tokens with tenant context should use that tenant_id
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
logger.debug(f"Service authentication with tenant context: {service_name}, tenant_id: {payload['tenant_id']}")
|
||||
else:
|
||||
logger.debug(f"Service authentication: {service_name}")
|
||||
|
||||
return base_context
|
||||
|
||||
@@ -556,18 +576,30 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
Inject user and tenant context headers for downstream services
|
||||
ENHANCED: Added logging to verify header injection
|
||||
"""
|
||||
# Log what we're injecting for debugging
|
||||
logger.debug(
|
||||
"Injecting context headers",
|
||||
# Enhanced logging for debugging
|
||||
logger.info(
|
||||
"🔧 Injecting context headers",
|
||||
user_id=user_context.get("user_id"),
|
||||
user_type=user_context.get("type", ""),
|
||||
service_name=user_context.get("service", ""),
|
||||
role=user_context.get("role", ""),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=user_context.get("is_demo", False),
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
# Add user context headers
|
||||
logger.debug(f"DEBUG: Injecting headers for user: {user_context.get('user_id')}, is_demo: {user_context.get('is_demo', False)}")
|
||||
logger.debug(f"DEBUG: request.headers object id: {id(request.headers)}, _list id: {id(request.headers.__dict__.get('_list', []))}")
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {
|
||||
"x-user-id": user_context["user_id"],
|
||||
"x-user-email": user_context["email"],
|
||||
"x-user-role": user_context.get("role", "user")
|
||||
}
|
||||
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-id", user_context["user_id"].encode()
|
||||
))
|
||||
@@ -607,10 +639,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Add is_demo flag for demo sessions
|
||||
is_demo = user_context.get("is_demo", False)
|
||||
logger.debug(f"DEBUG: is_demo value: {is_demo}, type: {type(is_demo)}")
|
||||
if is_demo:
|
||||
logger.info(f"🎭 Adding demo session headers",
|
||||
demo_session_id=user_context.get("demo_session_id", ""),
|
||||
demo_account_type=user_context.get("demo_account_type", ""),
|
||||
path=request.url.path)
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-is-demo", b"true"
|
||||
))
|
||||
else:
|
||||
logger.debug(f"DEBUG: Not adding demo headers because is_demo is: {is_demo}")
|
||||
|
||||
# Add demo session context headers for backend services
|
||||
demo_session_id = user_context.get("demo_session_id", "")
|
||||
|
||||
@@ -304,14 +304,27 @@ class DemoMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[dict]:
|
||||
"""Get session information from demo service"""
|
||||
"""Get session information from demo service using JWT service token"""
|
||||
try:
|
||||
# Create JWT service token for gateway-to-demo-session communication
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}"
|
||||
f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning("Demo session fetch failed",
|
||||
session_id=session_id,
|
||||
status_code=response.status_code,
|
||||
response_text=response.text[:200] if hasattr(response, 'text') else '')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get session info", session_id=session_id, error=str(e))
|
||||
|
||||
@@ -63,7 +63,9 @@ class AuthProxy:
|
||||
target_url = f"{auth_url}/{path}"
|
||||
|
||||
# Prepare headers (remove hop-by-hop headers)
|
||||
headers = self._prepare_headers(dict(request.headers))
|
||||
# IMPORTANT: Use request.headers directly to get headers added by middleware
|
||||
# Also check request.state for headers injected by middleware
|
||||
headers = self._prepare_headers(request.headers, request)
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
@@ -133,7 +135,7 @@ class AuthProxy:
|
||||
# Fall back to configured URL
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
@@ -141,10 +143,94 @@ class AuthProxy:
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
# Convert headers to dict - get ALL headers including those added by middleware
|
||||
# Middleware adds headers to _list, so we need to read from there
|
||||
logger.debug(f"DEBUG: headers type: {type(headers)}, has _list: {hasattr(headers, '_list')}, has raw: {hasattr(headers, 'raw')}")
|
||||
logger.debug(f"DEBUG: headers.__dict__ keys: {list(headers.__dict__.keys())}")
|
||||
logger.debug(f"DEBUG: '_list' in headers.__dict__: {'_list' in headers.__dict__}")
|
||||
|
||||
if hasattr(headers, '_list'):
|
||||
logger.debug(f"DEBUG: Entering _list branch")
|
||||
logger.debug(f"DEBUG: headers object id: {id(headers)}, _list id: {id(headers.__dict__.get('_list', []))}")
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
logger.debug(f"DEBUG: _list length: {len(all_headers_list)}")
|
||||
|
||||
# Debug: Show first few headers in the list
|
||||
debug_headers = []
|
||||
for i, (k, v) in enumerate(all_headers_list):
|
||||
if i < 5: # Show first 5 headers for debugging
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
debug_headers.append(f"{key}: {value}")
|
||||
logger.debug(f"DEBUG: First headers in _list: {debug_headers}")
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Debug: Show if x-user-id and x-is-demo are in the dict
|
||||
logger.debug(f"DEBUG: x-user-id in all_headers: {'x-user-id' in all_headers}, x-is-demo in all_headers: {'x-is-demo' in all_headers}")
|
||||
logger.debug(f"DEBUG: all_headers keys: {list(all_headers.keys())[:10]}...") # Show first 10 keys
|
||||
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
logger.debug(f"DEBUG: Found injected_headers in request.state: {request.state.injected_headers}")
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
logger.debug(f"DEBUG: Added x-user-id from request.state: {all_headers['x-user-id']}")
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
logger.debug(f"DEBUG: Added x-user-email from request.state: {all_headers['x-user-email']}")
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
logger.debug(f"DEBUG: Added x-user-role from request.state: {all_headers['x-user-role']}")
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
logger.debug(f"DEBUG: Added x-is-demo from request.state.is_demo_session")
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
logger.debug(f"DEBUG: Entering raw branch")
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# Fallback to raw headers if _list not available
|
||||
all_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
}
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
else:
|
||||
# Handle case where headers is already a dict
|
||||
logger.info(f"📤 Forwarding headers to auth service - x_user_id: {headers.get('x-user-id', 'MISSING')}, x_is_demo: {headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {headers.get('x-demo-session-id', 'MISSING')}, headers: {list(headers.keys())}")
|
||||
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
|
||||
@@ -110,16 +110,16 @@ async def _proxy_request(request: Request, target_path: str, service_url: str):
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = str(user.get('tenant_id', ''))
|
||||
|
||||
|
||||
# Add subscription context headers
|
||||
if user.get('subscription_tier'):
|
||||
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
|
||||
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
|
||||
|
||||
|
||||
if user.get('subscription_status'):
|
||||
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
|
||||
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
|
||||
|
||||
|
||||
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, subscription_tier={user.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
logger.warning(f"No user context available when forwarding subscription request to {url}")
|
||||
|
||||
@@ -714,15 +714,15 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
|
||||
# Forward headers and add user/tenant context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
|
||||
# Add tenant ID header if provided
|
||||
if tenant_id:
|
||||
headers["X-Tenant-ID"] = tenant_id
|
||||
|
||||
|
||||
# Add user context headers if available
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user = request.state.user
|
||||
@@ -731,16 +731,16 @@ async def _proxy_request(request: Request, target_path: str, service_url: str, t
|
||||
headers["x-user-role"] = str(user.get('role', 'user'))
|
||||
headers["x-user-full-name"] = str(user.get('full_name', ''))
|
||||
headers["x-tenant-id"] = tenant_id or str(user.get('tenant_id', ''))
|
||||
|
||||
|
||||
# Add subscription context headers
|
||||
if user.get('subscription_tier'):
|
||||
headers["x-subscription-tier"] = str(user.get('subscription_tier', ''))
|
||||
logger.debug(f"Forwarding subscription tier: {user.get('subscription_tier')}")
|
||||
|
||||
|
||||
if user.get('subscription_status'):
|
||||
headers["x-subscription-status"] = str(user.get('subscription_status', ''))
|
||||
logger.debug(f"Forwarding subscription status: {user.get('subscription_status')}")
|
||||
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user.get('user_id')}, email={user.get('email')}, tenant_id={tenant_id}, subscription_tier={user.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
|
||||
@@ -63,7 +63,9 @@ class UserProxy:
|
||||
target_url = f"{auth_url}/api/v1/auth/{path}"
|
||||
|
||||
# Prepare headers (remove hop-by-hop headers)
|
||||
headers = self._prepare_headers(dict(request.headers))
|
||||
# IMPORTANT: Use request.headers directly to get headers added by middleware
|
||||
# Also check request.state for headers injected by middleware
|
||||
headers = self._prepare_headers(request.headers, request)
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
@@ -133,23 +135,64 @@ class UserProxy:
|
||||
# Fall back to configured URL
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding (remove hop-by-hop headers)"""
|
||||
# Remove hop-by-hop headers
|
||||
hop_by_hop_headers = {
|
||||
'connection', 'keep-alive', 'proxy-authenticate',
|
||||
'proxy-authorization', 'te', 'trailers', 'upgrade'
|
||||
}
|
||||
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
|
||||
# Convert headers to dict if it's a Headers object
|
||||
# This ensures we get ALL headers including those added by middleware
|
||||
if hasattr(headers, '_list'):
|
||||
# Get headers from the _list where middleware adds them
|
||||
all_headers_list = headers.__dict__.get('_list', [])
|
||||
|
||||
# Convert to dict for easier processing
|
||||
all_headers = {}
|
||||
for k, v in all_headers_list:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
|
||||
# Check if headers are missing and try to get them from request.state
|
||||
if request and hasattr(request, 'state') and hasattr(request.state, 'injected_headers'):
|
||||
# Add missing headers from request.state
|
||||
if 'x-user-id' not in all_headers and 'x-user-id' in request.state.injected_headers:
|
||||
all_headers['x-user-id'] = request.state.injected_headers['x-user-id']
|
||||
if 'x-user-email' not in all_headers and 'x-user-email' in request.state.injected_headers:
|
||||
all_headers['x-user-email'] = request.state.injected_headers['x-user-email']
|
||||
if 'x-user-role' not in all_headers and 'x-user-role' in request.state.injected_headers:
|
||||
all_headers['x-user-role'] = request.state.injected_headers['x-user-role']
|
||||
|
||||
# Add is_demo flag if this is a demo session
|
||||
if hasattr(request.state, 'is_demo_session') and request.state.is_demo_session:
|
||||
all_headers['x-is-demo'] = 'true'
|
||||
|
||||
# Filter out hop-by-hop headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in all_headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
elif hasattr(headers, 'raw'):
|
||||
# FastAPI/Starlette Headers object - use raw to get all headers
|
||||
filtered_headers = {
|
||||
k.decode() if isinstance(k, bytes) else k: v.decode() if isinstance(v, bytes) else v
|
||||
for k, v in headers.raw
|
||||
if (k.decode() if isinstance(k, bytes) else k).lower() not in hop_by_hop_headers
|
||||
}
|
||||
else:
|
||||
# Already a dict
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in hop_by_hop_headers
|
||||
}
|
||||
|
||||
# Add gateway identifier
|
||||
filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
|
||||
filtered_headers['X-Gateway-Version'] = '1.0.0'
|
||||
|
||||
|
||||
return filtered_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
|
||||
Reference in New Issue
Block a user