Improve auth flow
This commit is contained in:
@@ -1,41 +1,76 @@
|
||||
# shared/auth/decorators.py - NEW FILE
|
||||
"""
|
||||
Authentication decorators for FastAPI
|
||||
Authentication decorators for microservices
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
import httpx
|
||||
import logging
|
||||
from fastapi import HTTPException, status, Request
|
||||
from typing import Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
def verify_service_token(auth_service_url: str):
|
||||
"""Verify service token with auth service"""
|
||||
def require_authentication(func: Callable) -> Callable:
|
||||
"""Decorator to require authentication - assumes gateway has validated token"""
|
||||
|
||||
async def verify_token(token: str = Depends(security)):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{auth_service_url}/verify",
|
||||
headers={"Authorization": f"Bearer {token.credentials}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authentication credentials"
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Auth service unavailable: {e}")
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Find request object in arguments
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if not request:
|
||||
# Check kwargs
|
||||
request = kwargs.get('request')
|
||||
|
||||
if not request:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Authentication service unavailable"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Request object not found"
|
||||
)
|
||||
|
||||
# Check if user context exists (set by gateway)
|
||||
if not hasattr(request.state, 'user') or not request.state.user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return verify_token
|
||||
return wrapper
|
||||
|
||||
def require_tenant_access(func: Callable) -> Callable:
|
||||
"""Decorator to require tenant access"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Find request object
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if not request or not hasattr(request.state, 'tenant_id'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Tenant access required"
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_current_user(request: Request) -> dict:
|
||||
"""Get current user from request state"""
|
||||
if not hasattr(request.state, 'user'):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not authenticated"
|
||||
)
|
||||
return request.state.user
|
||||
|
||||
def get_current_tenant_id(request: Request) -> Optional[str]:
|
||||
"""Get current tenant ID from request state"""
|
||||
return getattr(request.state, 'tenant_id', None)
|
||||
|
||||
@@ -1,58 +1,97 @@
|
||||
# shared/auth/jwt_handler.py - IMPROVED VERSION
|
||||
"""
|
||||
Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
Enhanced JWT Handler with proper token structure
|
||||
"""
|
||||
|
||||
from jose import jwt
|
||||
from jose import jwt, JWTError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
import structlog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class JWTHandler:
|
||||
"""JWT token handling for microservices"""
|
||||
"""Enhanced JWT token handling"""
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = "HS256"):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
|
||||
def create_access_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT access token"""
|
||||
to_encode = data.copy()
|
||||
def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create JWT access token WITHOUT tenant_id
|
||||
Tenant context is determined per request, not stored in token
|
||||
"""
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"], # Standard JWT subject
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc)
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created access token for user {user_data['email']}")
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT refresh token"""
|
||||
to_encode = data.copy()
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=7)
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=7)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc)
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify and decode JWT token"""
|
||||
"""Verify and decode JWT token with comprehensive validation"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
logger.warning(f"Token missing required fields: {required_fields}")
|
||||
return None
|
||||
|
||||
# Validate token type
|
||||
if payload.get("type") not in ["access", "refresh"]:
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# Check expiration (jose handles this, but double-check)
|
||||
exp = payload.get("exp")
|
||||
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT validation failed: {e}")
|
||||
return None
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error validating token: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user