New token arch
This commit is contained in:
@@ -5,7 +5,7 @@ FIXED VERSION - Proper JWT verification and token structure handling
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
@@ -60,6 +60,16 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# SECURITY: Remove any incoming x-subscription-* headers
|
||||
# These will be re-injected from verified JWT only
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not k.decode().lower().startswith('x-subscription-')
|
||||
and not k.decode().lower().startswith('x-user-')
|
||||
and not k.decode().lower().startswith('x-tenant-')
|
||||
]
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
@@ -168,7 +178,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
# Get tenant subscription tier and inject into user context
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
# NEW: Use JWT data if available, skip HTTP call
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
subscription_tier = user_context.get("subscription_tier")
|
||||
logger.debug("Using subscription tier from JWT", tier=subscription_tier)
|
||||
else:
|
||||
# Only for old tokens - remove after full rollout
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
|
||||
@@ -255,6 +272,22 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if payload and self._validate_token_payload(payload):
|
||||
logger.debug("Token validated locally")
|
||||
|
||||
# NEW: Check token freshness for subscription changes (async)
|
||||
if payload.get("tenant_id") and request:
|
||||
try:
|
||||
is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
|
||||
if not is_fresh:
|
||||
logger.warning("Stale token detected - subscription changed since token was issued",
|
||||
user_id=payload.get("user_id"),
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is stale - subscription has changed"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check failed, allowing token", error=str(e))
|
||||
# Allow token if check fails (fail open for availability)
|
||||
|
||||
# Check if token is near expiry and set flag for response header
|
||||
if request:
|
||||
import time
|
||||
@@ -321,6 +354,78 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
|
||||
|
||||
# NEW: Check token freshness for subscription changes
|
||||
if payload.get("tenant_id"):
|
||||
try:
|
||||
# Note: We can't await here because this is a sync function
|
||||
# Token freshness will be checked in the async dispatch method
|
||||
# For now, just log that we would check freshness
|
||||
logger.debug("Token freshness check would be performed in async context",
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check setup failed", error=str(e))
|
||||
|
||||
return True
|
||||
|
||||
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate JWT payload integrity beyond signature verification.
|
||||
Prevents edge cases where payload might be malformed.
|
||||
"""
|
||||
# Required fields must exist
|
||||
required_fields = ["user_id", "email", "exp", "iat", "iss"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
logger.warning("JWT missing required fields", missing=[f for f in required_fields if f not in payload])
|
||||
return False
|
||||
|
||||
# Issuer must be our auth service
|
||||
if payload.get("iss") != "bakery-auth":
|
||||
logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
|
||||
return False
|
||||
|
||||
# Token type must be valid
|
||||
if payload.get("type") not in ["access", "service"]:
|
||||
logger.warning("JWT has invalid type", token_type=payload.get("type"))
|
||||
return False
|
||||
|
||||
# Subscription tier must be valid if present
|
||||
valid_tiers = ["starter", "professional", "enterprise"]
|
||||
if payload.get("subscription"):
|
||||
tier = payload["subscription"].get("tier", "").lower()
|
||||
if tier and tier not in valid_tiers:
|
||||
logger.warning("JWT has invalid subscription tier", tier=tier)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
|
||||
"""
|
||||
Verify token was issued after the last subscription change.
|
||||
Prevents use of stale tokens with old subscription data.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return True # Skip check if no Redis
|
||||
|
||||
try:
|
||||
subscription_changed_at = await self.redis_client.get(
|
||||
f"tenant:{tenant_id}:subscription_changed_at"
|
||||
)
|
||||
|
||||
if subscription_changed_at:
|
||||
changed_timestamp = float(subscription_changed_at)
|
||||
token_issued_at = payload.get("iat", 0)
|
||||
|
||||
if token_issued_at < changed_timestamp:
|
||||
logger.warning(
|
||||
"Token issued before subscription change",
|
||||
token_iat=token_issued_at,
|
||||
subscription_changed=changed_timestamp,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return False # Token is stale
|
||||
except Exception as e:
|
||||
logger.warning("Failed to check token freshness", error=str(e))
|
||||
|
||||
return True
|
||||
|
||||
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -328,6 +433,13 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
Convert JWT payload to user context format
|
||||
FIXED: Proper mapping between JWT structure and user context
|
||||
"""
|
||||
# NEW: Validate JWT integrity before processing
|
||||
if not self._validate_jwt_integrity(payload):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT payload"
|
||||
)
|
||||
|
||||
base_context = {
|
||||
"user_id": payload["user_id"],
|
||||
"email": payload["email"],
|
||||
@@ -336,6 +448,20 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"role": payload.get("role", "user"),
|
||||
}
|
||||
|
||||
# NEW: Extract subscription from JWT
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
base_context["tenant_role"] = payload.get("tenant_role", "member")
|
||||
|
||||
if payload.get("subscription"):
|
||||
sub = payload["subscription"]
|
||||
base_context["subscription_tier"] = sub.get("tier", "starter")
|
||||
base_context["subscription_status"] = sub.get("status", "active")
|
||||
base_context["subscription_from_jwt"] = True # Flag to skip HTTP
|
||||
|
||||
if payload.get("tenant_access"):
|
||||
base_context["tenant_access"] = payload["tenant_access"]
|
||||
|
||||
if payload.get("service"):
|
||||
service_name = payload["service"]
|
||||
base_context["service"] = service_name
|
||||
@@ -571,4 +697,4 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tenant subscription tier: {e}")
|
||||
return "starter" # Default to starter on error
|
||||
return "starter" # Default to starter on error
|
||||
|
||||
Reference in New Issue
Block a user