Fix gateway
This commit is contained in:
@@ -3,8 +3,10 @@ Authentication middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import httpx
|
||||
from typing import Optional
|
||||
|
||||
@@ -28,74 +30,77 @@ PUBLIC_ROUTES = [
|
||||
"/api/v1/auth/refresh"
|
||||
]
|
||||
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
"""Authentication middleware"""
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Authentication middleware class"""
|
||||
|
||||
# Check if route requires authentication
|
||||
if _is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Get token from header
|
||||
token = _extract_token(request)
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
# First try to verify token locally
|
||||
payload = jwt_handler.verify_token(token)
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with authentication"""
|
||||
|
||||
if payload:
|
||||
# Add user info to request state
|
||||
request.state.user = payload
|
||||
# Check if route requires authentication
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
else:
|
||||
# Token invalid or expired, verify with auth service
|
||||
user_info = await _verify_with_auth_service(token)
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
|
||||
# Get token from header
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
# First try to verify token locally
|
||||
payload = jwt_handler.verify_token(token)
|
||||
|
||||
if payload:
|
||||
# Add user info to request state
|
||||
request.state.user = payload
|
||||
return await call_next(request)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
# Token invalid or expired, verify with auth service
|
||||
user_info = await self._verify_with_auth_service(token)
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
return await call_next(request)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication failed"}
|
||||
)
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
"""Check if route is public"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from request"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
|
||||
"""Verify token with auth service"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication failed"}
|
||||
)
|
||||
|
||||
def _is_public_route(path: str) -> bool:
|
||||
"""Check if route is public"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from request"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
async def _verify_with_auth_service(token: str) -> Optional[dict]:
|
||||
"""Verify token with auth service"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service verification failed: {e}")
|
||||
return None
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service verification failed: {e}")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user