REFACTOR API gateway fix 3
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# gateway/app/middleware/auth.py
|
||||
"""
|
||||
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
|
||||
FIXED VERSION - Proper JWT verification and token structure handling
|
||||
"""
|
||||
|
||||
import structlog
|
||||
@@ -9,6 +10,8 @@ from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
@@ -16,7 +19,7 @@ from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_f
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# JWT handler for local token validation
|
||||
# JWT handler for local token validation - using SAME configuration as auth service
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
# Routes that don't require authentication
|
||||
@@ -63,13 +66,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
# ✅ STEP 2: Verify token and get user context
|
||||
# Pass self.redis_client to _verify_token to enable caching
|
||||
user_context = await self._verify_token(token)
|
||||
if not user_context:
|
||||
logger.warning(f"Invalid token for route: {request.url.path}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
content={"detail": "User not authenticated"}
|
||||
)
|
||||
|
||||
# ✅ STEP 3: Extract tenant context from URL using shared utility
|
||||
@@ -78,11 +80,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
|
||||
if tenant_id and is_tenant_scoped_path(request.url.path):
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
# Ensure tenant_access_manager uses the redis_client from the middleware
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access( # Corrected method call
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
@@ -129,18 +130,22 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify JWT token with fallback strategy"""
|
||||
"""
|
||||
Verify JWT token with improved fallback strategy
|
||||
FIXED: Better error handling and token structure validation
|
||||
"""
|
||||
|
||||
# Try local JWT validation first (fast)
|
||||
# Strategy 1: Try local JWT validation first (fast)
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload and self._validate_token_payload(payload):
|
||||
logger.debug("Token validated locally")
|
||||
return payload
|
||||
# Convert JWT payload to user context format
|
||||
return self._jwt_payload_to_user_context(payload)
|
||||
except Exception as e:
|
||||
logger.debug(f"Local token validation failed: {e}")
|
||||
|
||||
# Check cache for recently validated tokens
|
||||
# Strategy 2: Check cache for recently validated tokens
|
||||
if self.redis_client:
|
||||
try:
|
||||
cached_user = await self._get_cached_user(token)
|
||||
@@ -150,7 +155,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
|
||||
# Verify with auth service (authoritative)
|
||||
# Strategy 3: Verify with auth service (authoritative)
|
||||
try:
|
||||
user_context = await self._verify_with_auth_service(token)
|
||||
if user_context:
|
||||
@@ -165,92 +170,134 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return None
|
||||
|
||||
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
|
||||
"""Validate JWT payload has required fields"""
|
||||
required_fields = ["user_id", "email", "exp"]
|
||||
return all(field in payload for field in required_fields)
|
||||
"""
|
||||
Validate JWT payload has required fields
|
||||
FIXED: Updated to match actual token structure from auth service
|
||||
"""
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
missing_fields = [field for field in required_fields if field not in payload]
|
||||
|
||||
if missing_fields:
|
||||
logger.warning(f"Token payload missing fields: {missing_fields}")
|
||||
return False
|
||||
|
||||
# Validate token type
|
||||
if payload.get("type") != "access":
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert JWT payload to user context format
|
||||
FIXED: Proper mapping between JWT structure and user context
|
||||
"""
|
||||
return {
|
||||
"user_id": payload["user_id"],
|
||||
"email": payload["email"],
|
||||
"exp": payload["exp"],
|
||||
"valid": True
|
||||
}
|
||||
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify token with auth service"""
|
||||
"""
|
||||
Verify token with auth service
|
||||
FIXED: Improved error handling and response parsing
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
auth_response = response.json()
|
||||
|
||||
# Validate auth service response structure
|
||||
if auth_response.get("valid") and auth_response.get("user_id"):
|
||||
return {
|
||||
"user_id": auth_response["user_id"],
|
||||
"email": auth_response["email"],
|
||||
"exp": auth_response.get("exp"),
|
||||
"valid": True
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Auth service returned invalid response: {auth_response}")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Auth service returned {response.status_code}")
|
||||
logger.warning(f"Auth service returned {response.status_code}: {response.text}")
|
||||
return None
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout during token verification")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service error: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user context from cache"""
|
||||
"""
|
||||
Get user context from cache
|
||||
FIXED: Better error handling and JSON parsing
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"auth:token:{hash(token)}"
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}" # Use modulo for shorter keys
|
||||
try:
|
||||
cached_data = await self.redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
import json
|
||||
if isinstance(cached_data, bytes):
|
||||
cached_data = cached_data.decode()
|
||||
return json.loads(cached_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse cached user data: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get failed: {e}")
|
||||
logger.warning(f"Cache lookup error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
|
||||
"""Cache user context for 5 minutes"""
|
||||
async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Cache user context
|
||||
FIXED: Better error handling and expiration
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
cache_key = f"auth:token:{hash(token)}"
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}"
|
||||
try:
|
||||
import json
|
||||
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
|
||||
# Cache for 5 minutes (shorter than token expiry)
|
||||
await self.redis_client.setex(
|
||||
cache_key,
|
||||
300, # 5 minutes
|
||||
json.dumps(user_context)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set failed: {e}")
|
||||
logger.warning(f"Failed to cache user context: {e}")
|
||||
|
||||
def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
|
||||
"""Inject authentication and tenant headers for downstream services"""
|
||||
def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services
|
||||
FIXED: Proper header injection
|
||||
"""
|
||||
# Add user context headers
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-id", user_context["user_id"].encode()
|
||||
))
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-user-email", user_context["email"].encode()
|
||||
))
|
||||
|
||||
# Remove any existing auth headers to prevent spoofing
|
||||
headers_to_remove = [
|
||||
"x-user-id", "x-user-email", "x-user-role",
|
||||
"x-tenant-id", "x-tenant-verified", "x-authenticated"
|
||||
]
|
||||
|
||||
for header in headers_to_remove:
|
||||
request.headers.__dict__["_list"] = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if k.lower() != header.lower()
|
||||
]
|
||||
|
||||
# Inject new headers
|
||||
new_headers = [
|
||||
(b"x-authenticated", b"true"),
|
||||
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
|
||||
(b"x-user-email", str(user_context.get("email", "")).encode()),
|
||||
(b"x-user-role", str(user_context.get("role", "user")).encode()),
|
||||
]
|
||||
|
||||
# Add tenant context if verified
|
||||
# Add tenant context if available
|
||||
if tenant_id:
|
||||
new_headers.extend([
|
||||
(b"x-tenant-id", tenant_id.encode()),
|
||||
(b"x-tenant-verified", b"true")
|
||||
])
|
||||
|
||||
# Add headers to request
|
||||
request.headers.__dict__["_list"].extend(new_headers)
|
||||
|
||||
logger.debug(f"Injected context headers",
|
||||
user_id=user_context.get("user_id"),
|
||||
tenant_id=tenant_id)
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-tenant-id", tenant_id.encode()
|
||||
))
|
||||
|
||||
# Add gateway identification
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-forwarded-by", b"bakery-gateway"
|
||||
))
|
||||
Reference in New Issue
Block a user