Files
bakery-ia/shared/auth/jwt_handler.py

233 lines
8.2 KiB
Python
Raw Normal View History

2025-07-26 20:04:24 +02:00
# shared/auth/jwt_handler.py
"""
2025-07-26 20:04:24 +02:00
Enhanced JWT Handler with proper token structure and validation
FIXED VERSION - Consistent token format between all services
"""
2025-07-19 17:49:03 +02:00
from jose import jwt, JWTError
2025-07-18 07:46:56 +02:00
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
2025-07-19 17:49:03 +02:00
import structlog
2025-07-19 17:49:03 +02:00
# Module-level structlog logger used by JWTHandler for its debug/warning/error messages.
logger = structlog.get_logger()
class JWTHandler:
    """Issue, verify and inspect JWTs signed with a single shared secret.

    Every token produced here has the same base claims so that any service
    holding the same secret can validate it with ``verify_token``:

    * access tokens: ``sub``, ``user_id``, ``email``, ``type="access"`` and,
      when supplied, ``full_name`` / ``is_verified``;
    * refresh tokens: ``sub``, ``user_id``, ``type="refresh"``;
    * both kinds: ``exp`` and ``iat``.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: JWS algorithm name passed to ``jose.jwt``
                (default ``"HS256"``).
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    @staticmethod
    def _normalize_user_id(user_id: Any) -> Any:
        """Return *user_id* as a string when it is a ``uuid.UUID``; otherwise unchanged.

        UUID objects are not serializable by the standard ``json`` module used
        to encode the payload, and ``_validate_payload`` requires ``user_id``
        to be a ``str``, so a UUID would otherwise yield an unusable token.
        """
        from uuid import UUID  # local import keeps the module import block untouched

        return str(user_id) if isinstance(user_id, UUID) else user_id

    def _sign(
        self,
        claims: Dict[str, Any],
        expires_delta: Optional[timedelta],
        default_ttl: timedelta,
    ) -> str:
        """Add ``exp``/``iat`` to *claims* (in place) and return the signed token.

        A falsy *expires_delta* (``None`` or a zero timedelta) falls back to
        *default_ttl*. Both timestamps come from a single clock reading so the
        token lifetime is exactly the requested delta.
        """
        issued_at = datetime.now(timezone.utc)
        claims["exp"] = issued_at + (expires_delta or default_ttl)
        claims["iat"] = issued_at
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed access token.

        Args:
            user_data: Must contain ``user_id`` and ``email``; ``full_name``
                and ``is_verified`` are copied into the token when present.
            expires_delta: Token lifetime; 30 minutes when omitted or falsy.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` or ``email`` is missing from *user_data*.
        """
        user_id = self._normalize_user_id(user_data["user_id"])
        to_encode: Dict[str, Any] = {
            "sub": user_id,  # standard JWT subject claim
            "user_id": user_id,  # explicit copy; _validate_payload requires sub == user_id
            "email": user_data["email"],
            "type": "access",
        }
        for optional_field in ("full_name", "is_verified"):
            if optional_field in user_data:
                to_encode[optional_field] = user_data[optional_field]

        encoded_jwt = self._sign(to_encode, expires_delta, timedelta(minutes=30))
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed refresh token carrying only identity claims.

        Args:
            user_data: Must contain ``user_id``; other keys are ignored.
            expires_delta: Token lifetime; 7 days when omitted or falsy.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` is missing from *user_data*.
        """
        user_id = self._normalize_user_id(user_data["user_id"])
        to_encode: Dict[str, Any] = {
            "sub": user_id,
            "user_id": user_id,
            "type": "refresh",
        }

        encoded_jwt = self._sign(to_encode, expires_delta, timedelta(days=7))
        logger.debug(f"Created refresh token for user {user_data['user_id']}")
        return encoded_jwt

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify signature, expiry and payload structure of *token*.

        Returns:
            The decoded claims, or ``None`` if the token is invalid, expired,
            or lacks the structure produced by this class. Failures are
            logged and reported as ``None`` rather than raised.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": True},
            )

            if not self._validate_payload(payload):
                logger.warning("Token payload validation failed")
                return None

            # Redundant with verify_exp above; kept as a defensive second check.
            exp = payload.get("exp")
            if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
                logger.warning("Token has expired")
                return None

            logger.debug(f"Token verified successfully for user {payload.get('user_id')}")
            return payload

        except jwt.ExpiredSignatureError:
            logger.warning("Token has expired")
            return None
        except jwt.JWTClaimsError as e:
            logger.warning(f"Token claims validation failed: {e}")
            return None
        except jwt.JWTError as e:
            logger.warning(f"Token validation failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error during token verification: {e}")
            return None

    def decode_token_unsafe(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the token's claims WITHOUT checking signature or expiry.

        For debugging/inspection only: the result must not be trusted for
        authorization. Returns ``None`` (and logs) if the token cannot be parsed.
        """
        try:
            # python-jose's ``jwt.decode`` requires a key argument, so it cannot
            # be called key-less here; ``get_unverified_claims`` is the jose API
            # for reading claims without verification.
            return jwt.get_unverified_claims(token)
        except Exception as e:
            logger.error(f"Failed to decode token: {e}")
            return None

    def _validate_payload(self, payload: Dict[str, Any]) -> bool:
        """Check that *payload* has the structure produced by this class.

        Requires ``sub``, ``user_id``, ``type``, ``exp`` and ``iat``; ``type``
        must be ``"access"`` or ``"refresh"``; access tokens must also carry
        ``email``; ``user_id`` must be a non-empty string equal to ``sub``.
        Each failure is logged and yields ``False``.
        """
        for field in ("sub", "user_id", "type", "exp", "iat"):
            if field not in payload:
                logger.warning(f"Missing required field in token: {field}")
                return False

        token_type = payload.get("type")
        if token_type not in ("access", "refresh"):
            logger.warning(f"Invalid token type: {token_type}")
            return False

        if token_type == "access" and "email" not in payload:
            logger.warning("Access token missing email field")
            return False

        # Only type/non-emptiness is checked here, not UUID syntax.
        user_id = payload.get("user_id")
        if not user_id or not isinstance(user_id, str):
            logger.warning("Invalid user_id in token")
            return False

        if payload.get("sub") != user_id:
            logger.warning("Token subject does not match user_id")
            return False

        return True

    def extract_user_id(self, token: str) -> Optional[str]:
        """Return the ``user_id`` claim without verifying the token.

        Suitable only for non-security uses such as logging; returns ``None``
        if the token cannot be decoded or has no ``user_id`` claim.
        """
        try:
            payload = self.decode_token_unsafe(token)
            if payload:
                return payload.get("user_id")
        except Exception as e:
            logger.warning(f"Failed to extract user ID from token: {e}")
        return None

    def is_token_expired(self, token: str) -> bool:
        """Report whether *token*'s ``exp`` lies in the past, without verifying it.

        Returns ``True`` when the token cannot be decoded, has no ``exp``
        claim, or ``exp`` cannot be interpreted: unknown is treated as expired.
        """
        try:
            payload = self.decode_token_unsafe(token)
            if payload and "exp" in payload:
                exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
                return exp < datetime.now(timezone.utc)
        except Exception as e:
            logger.warning(f"Failed to check token expiration: {e}")
        return True  # Assume expired if we can't check

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """Summarize *token* for debugging.

        Returns a dict with the unverified claims (``user_id``, ``email``,
        ``type``, ``exp``, ``iat``), an ``expired`` flag, and ``valid``, which
        is ``True`` only if ``verify_token`` accepts the token.
        """
        info: Dict[str, Any] = {
            "valid": False,
            "expired": True,
            "user_id": None,
            "email": None,
            "type": None,
            "exp": None,
            "iat": None,
        }
        try:
            payload = self.decode_token_unsafe(token)
            if payload:
                info.update({
                    "user_id": payload.get("user_id"),
                    "email": payload.get("email"),
                    "type": payload.get("type"),
                    "exp": payload.get("exp"),
                    "iat": payload.get("iat"),
                    "expired": self.is_token_expired(token),
                })
            info["valid"] = self.verify_token(token) is not None
        except Exception as e:
            logger.warning(f"Failed to get token info: {e}")
        return info