292 lines
11 KiB
Python
Executable File
292 lines
11 KiB
Python
Executable File
# shared/auth/jwt_handler.py
|
|
"""
Enhanced JWT handler with a consistent token structure and validation.

Fixed version: access, refresh, and service tokens use the same claim
format across all services.
"""
|
|
|
|
from jose import jwt, JWTError
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional, Dict, Any
|
|
import structlog
|
|
|
|
# Module-level structlog logger used by JWTHandler for its debug/warning/error events.
logger = structlog.get_logger()
|
|
|
|
class JWTHandler:
    """Enhanced JWT token handling with a consistent claim format.

    Every token is signed with ``secret_key`` using ``algorithm`` and carries a
    ``type`` claim (``"access"``, ``"refresh"`` or ``"service"``) so consumers
    can tell token kinds apart. Tokens built by ``create_access_token``,
    ``create_refresh_token`` and ``create_service_token`` additionally carry
    ``exp``, ``iat`` and ``iss`` claims.
    """

    # Default lifetimes used when callers pass no ``expires_delta``.
    # Class attributes so deployments/subclasses can override them.
    DEFAULT_ACCESS_TOKEN_TTL = timedelta(minutes=30)
    DEFAULT_REFRESH_TOKEN_TTL = timedelta(days=30)
    DEFAULT_SERVICE_TOKEN_TTL = timedelta(hours=1)
    # Value written into the ``iss`` claim of tokens built by this class.
    ISSUER = "bakery-auth"

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """Store the signing configuration.

        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: Algorithm name passed to ``jwt.encode``/``jwt.decode``.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _encode_typed_payload(self, payload: Dict[str, Any], token_type: str) -> str:
        """Encode a shallow copy of *payload* with ``type`` forced to *token_type*.

        The copy keeps the caller's dict untouched: ``jwt.encode`` rewrites
        datetime ``exp``/``iat``/``nbf`` values in place, and the ``type``
        claim must not leak back into the caller's object.

        Raises:
            ValueError: if the payload cannot be copied or encoded.
        """
        try:
            # Built inside the try so a non-mapping payload still surfaces as ValueError.
            claims = {**payload, "type": token_type}
            encoded_jwt = jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
        except Exception as e:
            logger.error(f"{token_type.capitalize()} token creation failed: {e}")
            raise ValueError(f"Failed to encode {token_type} token: {str(e)}") from e

        logger.debug(f"Created {token_type} token with payload keys: {list(claims.keys())}")
        return encoded_jwt

    def _stamp_and_encode(
        self,
        claims: Dict[str, Any],
        expires_delta: Optional[timedelta],
        default_ttl: timedelta,
    ) -> str:
        """Add ``exp``/``iat``/``iss`` to *claims* (in place) and return the JWT.

        ``exp`` and ``iat`` are derived from a single clock reading. An
        explicit ``expires_delta`` is honoured even when it is ``timedelta(0)``;
        only ``None`` selects *default_ttl*.
        """
        now = datetime.now(timezone.utc)
        lifetime = default_ttl if expires_delta is None else expires_delta
        claims.update({
            "exp": now + lifetime,
            "iat": now,
            "iss": self.ISSUER,
        })
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    # ------------------------------------------------------------------ #
    # Token creation
    # ------------------------------------------------------------------ #

    def create_access_token_from_payload(self, payload: Dict[str, Any]) -> str:
        """Create a JWT ACCESS token from a complete, caller-built payload.

        The ``type`` claim is forced to ``"access"``; no ``exp``/``iat``/``iss``
        claims are added, so the caller must include any it needs. The
        caller's *payload* dict is not modified.

        Raises:
            ValueError: if encoding fails.
        """
        return self._encode_typed_payload(payload, "access")

    def create_refresh_token_from_payload(self, payload: Dict[str, Any]) -> str:
        """Create a JWT REFRESH token from a complete, caller-built payload.

        The ``type`` claim is forced to ``"refresh"``; no ``exp``/``iat``/``iss``
        claims are added, so the caller must include any it needs. The
        caller's *payload* dict is not modified.

        Raises:
            ValueError: if encoding fails.
        """
        return self._encode_typed_payload(payload, "refresh")

    def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token with the standard claim set (legacy method).

        Args:
            user_data: Must contain ``user_id`` and ``email``; ``full_name``,
                ``is_verified`` and ``is_active`` are copied when present.
            expires_delta: Token lifetime; ``None`` uses
                ``DEFAULT_ACCESS_TOKEN_TTL`` (30 minutes).

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: if ``user_id`` or ``email`` is missing from *user_data*.
        """
        to_encode: Dict[str, Any] = {
            "sub": user_data["user_id"],
            "user_id": user_data["user_id"],
            "email": user_data["email"],
            "type": "access",
        }

        # Optional profile fields are included only when the caller supplied them.
        for field in ("full_name", "is_verified", "is_active"):
            if field in user_data:
                to_encode[field] = user_data[field]

        encoded_jwt = self._stamp_and_encode(to_encode, expires_delta, self.DEFAULT_ACCESS_TOKEN_TTL)
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT refresh token with a minimal claim set (legacy method).

        Args:
            user_data: Must contain ``user_id``. ``jti`` is used when present,
                otherwise a random UUID4 is generated; ``email`` is included
                only when present and truthy.
            expires_delta: Token lifetime; ``None`` uses
                ``DEFAULT_REFRESH_TOKEN_TTL`` (30 days).

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: if ``user_id`` is missing from *user_data*.
        """
        import uuid  # local import keeps this block self-contained

        to_encode: Dict[str, Any] = {
            "sub": user_data["user_id"],
            "user_id": user_data["user_id"],
            "type": "refresh",
        }

        # A unique identifier keeps otherwise-identical refresh tokens distinct.
        to_encode["jti"] = user_data["jti"] if "jti" in user_data else str(uuid.uuid4())

        # Email is optional for refresh tokens; skip it when absent or empty.
        if user_data.get("email"):
            to_encode["email"] = user_data["email"]

        encoded_jwt = self._stamp_and_encode(to_encode, expires_delta, self.DEFAULT_REFRESH_TOKEN_TTL)
        logger.debug(f"Created refresh token for user {user_data['user_id']}")
        return encoded_jwt

    def create_service_token(
        self,
        service_name: str,
        expires_delta: Optional[timedelta] = None,
        tenant_id: Optional[str] = None
    ) -> str:
        """Create a JWT SERVICE token for inter-service communication.

        Single source of truth for service token creation. The token carries
        ``role="admin"`` and ``is_service=True``.

        Args:
            service_name: Name of the service (e.g. 'auth-service', 'demo-session').
            expires_delta: Token lifetime; ``None`` uses
                ``DEFAULT_SERVICE_TOKEN_TTL`` (1 hour).
            tenant_id: Optional tenant ID for tenant-scoped service operations;
                added as the ``tenant_id`` claim only when truthy.

        Returns:
            The encoded JWT service token.
        """
        to_encode: Dict[str, Any] = {
            "sub": service_name,
            "user_id": f"{service_name}-service",
            "email": f"{service_name}-service@internal",
            "service": service_name,
            "type": "service",
            "role": "admin",  # Services have admin privileges
            "is_service": True,
            "full_name": f"{service_name.title()} Service",
            "is_verified": True,
            "is_active": True,
        }

        # Include tenant context when provided for tenant-scoped operations.
        if tenant_id:
            to_encode["tenant_id"] = tenant_id

        encoded_jwt = self._stamp_and_encode(to_encode, expires_delta, self.DEFAULT_SERVICE_TOKEN_TTL)
        logger.debug(f"Created service token for service {service_name}", tenant_id=tenant_id)
        return encoded_jwt

    # ------------------------------------------------------------------ #
    # Verification & inspection
    # ------------------------------------------------------------------ #

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify the signature and claims of *token* and return its payload.

        Returns:
            The decoded payload, or ``None`` if the token is invalid, expired
            or cannot be decoded. This method never raises.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])

            # python-jose already validates ``exp`` (whole-second precision);
            # this re-check is kept as defense in depth.
            exp_timestamp = payload.get("exp")
            if exp_timestamp:
                exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
                if datetime.now(timezone.utc) > exp_datetime:
                    logger.debug("Token is expired")
                    return None

            logger.debug(f"Token verified successfully, type: {payload.get('type', 'unknown')}")
            return payload

        except JWTError as e:
            logger.warning(f"JWT verification failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Token verification error: {e}")
            return None

    def decode_token_no_verify(self, token: str) -> Dict[str, Any]:
        """Decode *token* WITHOUT any verification (for inspection only).

        Neither the signature nor any claim (``exp``, ``iat``, ``nbf``, ``aud``,
        ...) is checked, so expired or foreign-audience tokens can still be
        inspected. Never trust the result for authorization.

        Raises:
            ValueError: if the token is not a well-formed JWT.
        """
        try:
            # ``jwt.decode(..., options={"verify_signature": False})`` still
            # validated the time claims and raised on expired tokens.
            return jwt.get_unverified_claims(token)
        except Exception as e:
            logger.error(f"Token decoding failed: {e}")
            raise ValueError("Invalid token format") from e

    def get_token_type(self, token: str) -> Optional[str]:
        """Return the unverified ``type`` claim, or ``None`` if undecodable/absent."""
        try:
            payload = self.decode_token_no_verify(token)
            return payload.get("type")
        except Exception:
            return None

    def is_token_expired(self, token: str) -> bool:
        """Return whether *token* is past its ``exp`` claim (no signature check).

        Tokens without a truthy ``exp`` claim, and tokens that cannot be
        decoded, are conservatively reported as expired.
        """
        try:
            exp_timestamp = self.decode_token_no_verify(token).get("exp")
            if not exp_timestamp:
                return True
            exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
            return datetime.now(timezone.utc) > exp_datetime
        except Exception:
            return True

    def extract_user_id(self, token: str) -> Optional[str]:
        """Return the unverified ``user_id`` claim, or ``None`` on failure.

        Useful for quick user identification; never use it for authorization.
        """
        try:
            payload = self.decode_token_no_verify(token)
            if payload:
                return payload.get("user_id")
        except Exception as e:
            logger.warning(f"Failed to extract user ID from token: {e}")

        return None

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """Collect token details for debugging.

        Returns:
            A dict with ``valid`` (full verification passed), ``expired``,
            ``user_id``, ``email``, ``type``, ``exp`` and ``iat``. Claim fields
            stay ``None`` (and ``expired`` stays ``True``) when the token
            cannot be decoded at all.
        """
        info: Dict[str, Any] = {
            "valid": False,
            "expired": True,
            "user_id": None,
            "email": None,
            "type": None,
            "exp": None,
            "iat": None
        }

        try:
            # Unverified decode first so details are available even for
            # expired or badly signed tokens.
            payload = self.decode_token_no_verify(token)
            if payload:
                info.update({
                    "user_id": payload.get("user_id"),
                    "email": payload.get("email"),
                    "type": payload.get("type"),
                    "exp": payload.get("exp"),
                    "iat": payload.get("iat"),
                    "expired": self.is_token_expired(token)
                })

            # Then full verification to report validity.
            verified_payload = self.verify_token(token)
            info["valid"] = verified_payload is not None

        except Exception as e:
            logger.warning(f"Failed to get token info: {e}")

        return info