Improve auth flow

This commit is contained in:
Urtzi Alfaro
2025-07-19 17:49:03 +02:00
parent f3071c00bd
commit abc8b68ab4
16 changed files with 1437 additions and 572 deletions

View File

@@ -1,41 +1,76 @@
# shared/auth/decorators.py - NEW FILE
"""
Authentication decorators for FastAPI
Authentication decorators for microservices
"""
from functools import wraps
from fastapi import HTTPException, Depends
from fastapi.security import HTTPBearer
import httpx
import logging
from fastapi import HTTPException, status, Request
from typing import Callable, Optional
logger = logging.getLogger(__name__)
security = HTTPBearer()
def verify_service_token(auth_service_url: str):
"""Verify service token with auth service"""
def require_authentication(func: Callable) -> Callable:
"""Decorator to require authentication - assumes gateway has validated token"""
async def verify_token(token: str = Depends(security)):
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{auth_service_url}/verify",
headers={"Authorization": f"Bearer {token.credentials}"}
)
if response.status_code == 200:
return response.json()
else:
raise HTTPException(
status_code=401,
detail="Invalid authentication credentials"
)
except httpx.RequestError as e:
logger.error(f"Auth service unavailable: {e}")
@wraps(func)
async def wrapper(*args, **kwargs):
# Find request object in arguments
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request:
# Check kwargs
request = kwargs.get('request')
if not request:
raise HTTPException(
status_code=503,
detail="Authentication service unavailable"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Request object not found"
)
# Check if user context exists (set by gateway)
if not hasattr(request.state, 'user') or not request.state.user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
return await func(*args, **kwargs)
return verify_token
return wrapper
def require_tenant_access(func: Callable) -> Callable:
"""Decorator to require tenant access"""
@wraps(func)
async def wrapper(*args, **kwargs):
# Find request object
request = None
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request or not hasattr(request.state, 'tenant_id'):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Tenant access required"
)
return await func(*args, **kwargs)
return wrapper
def get_current_user(request: Request) -> dict:
"""Get current user from request state"""
if not hasattr(request.state, 'user'):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not authenticated"
)
return request.state.user
def get_current_tenant_id(request: Request) -> Optional[str]:
"""Get current tenant ID from request state"""
return getattr(request.state, 'tenant_id', None)

View File

@@ -1,58 +1,97 @@
# shared/auth/jwt_handler.py - IMPROVED VERSION
"""
Shared JWT Authentication Handler
Used across all microservices for consistent authentication
Enhanced JWT Handler with proper token structure
"""
from jose import jwt
from jose import jwt, JWTError
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
import logging
import structlog
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
class JWTHandler:
"""JWT token handling for microservices"""
"""Enhanced JWT token handling"""
def __init__(self, secret_key: str, algorithm: str = "HS256"):
self.secret_key = secret_key
self.algorithm = algorithm
def create_access_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT access token"""
to_encode = data.copy()
def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Create JWT access token WITHOUT tenant_id
Tenant context is determined per request, not stored in token
"""
to_encode = {
"sub": user_data["user_id"], # Standard JWT subject
"user_id": user_data["user_id"],
"email": user_data["email"],
"type": "access"
}
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
to_encode.update({"exp": expire, "type": "access"})
to_encode.update({
"exp": expire,
"iat": datetime.now(timezone.utc)
})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
logger.debug(f"Created access token for user {user_data['email']}")
return encoded_jwt
def create_refresh_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""Create JWT refresh token"""
to_encode = data.copy()
to_encode = {
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"email": user_data["email"],
"type": "refresh"
}
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=7)
expire = datetime.now(timezone.utc) + timedelta(days=7)
to_encode.update({"exp": expire, "type": "refresh"})
to_encode.update({
"exp": expire,
"iat": datetime.now(timezone.utc)
})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
return encoded_jwt
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Verify and decode JWT token"""
"""Verify and decode JWT token with comprehensive validation"""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
# Validate required fields
required_fields = ["user_id", "email", "exp", "type"]
if not all(field in payload for field in required_fields):
logger.warning(f"Token missing required fields: {required_fields}")
return None
# Validate token type
if payload.get("type") not in ["access", "refresh"]:
logger.warning(f"Invalid token type: {payload.get('type')}")
return None
# Check expiration (jose handles this, but double-check)
exp = payload.get("exp")
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
logger.warning("Token has expired")
return None
return payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
except JWTError as e:
logger.warning(f"JWT validation failed: {e}")
return None
except jwt.JWTError:
logger.warning("Invalid token")
except Exception as e:
logger.error(f"Unexpected error validating token: {e}")
return None