Improve auth flow
This commit is contained in:
@@ -1,58 +1,97 @@
|
||||
# shared/auth/jwt_handler.py - IMPROVED VERSION
|
||||
"""
|
||||
Shared JWT Authentication Handler
|
||||
Used across all microservices for consistent authentication
|
||||
Enhanced JWT Handler with proper token structure
|
||||
"""
|
||||
|
||||
from jose import jwt
|
||||
from jose import jwt, JWTError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
import structlog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class JWTHandler:
|
||||
"""JWT token handling for microservices"""
|
||||
"""Enhanced JWT token handling"""
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = "HS256"):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
|
||||
def create_access_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT access token"""
|
||||
to_encode = data.copy()
|
||||
def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create JWT access token WITHOUT tenant_id
|
||||
Tenant context is determined per request, not stored in token
|
||||
"""
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"], # Standard JWT subject
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc)
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created access token for user {user_data['email']}")
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create JWT refresh token"""
|
||||
to_encode = data.copy()
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=7)
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=7)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc)
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify and decode JWT token"""
|
||||
"""Verify and decode JWT token with comprehensive validation"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
logger.warning(f"Token missing required fields: {required_fields}")
|
||||
return None
|
||||
|
||||
# Validate token type
|
||||
if payload.get("type") not in ["access", "refresh"]:
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# Check expiration (jose handles this, but double-check)
|
||||
exp = payload.get("exp")
|
||||
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
|
||||
logger.warning("Token has expired")
|
||||
return None
|
||||
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT validation failed: {e}")
|
||||
return None
|
||||
except jwt.JWTError:
|
||||
logger.warning("Invalid token")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error validating token: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user