# shared/auth/jwt_handler.py
"""
Enhanced JWT handler with a consistent token structure and validation.

Fixed version: uses the same token format across all services.
"""
2025-07-19 17:49:03 +02:00
|
|
|
from jose import jwt, JWTError
|
2025-07-18 07:46:56 +02:00
|
|
|
from datetime import datetime, timedelta, timezone
|
2025-07-17 13:09:24 +02:00
|
|
|
from typing import Optional, Dict, Any
|
2025-07-19 17:49:03 +02:00
|
|
|
import structlog
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 17:49:03 +02:00
|
|
|
logger = structlog.get_logger()
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
class JWTHandler:
    """Issue and verify JWT access/refresh tokens with one payload format.

    Tokens are signed and verified with python-jose (``jose.jwt``) using
    ``secret_key`` and a single configured ``algorithm``.

    Payload layout produced by this class:

    * access tokens: ``sub``, ``user_id``, ``email``, ``type="access"``,
      optionally ``full_name`` / ``is_verified``, plus ``exp`` and ``iat``;
    * refresh tokens: ``sub``, ``user_id``, ``type="refresh"``, plus ``exp``
      and ``iat``.

    ``sub`` and ``user_id`` always carry the same value.
    """

    # Token types accepted by _validate_payload.
    _TOKEN_TYPES = ("access", "refresh")

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        *,
        access_token_ttl: timedelta = timedelta(minutes=30),
        refresh_token_ttl: timedelta = timedelta(days=7),
    ):
        """Configure signing parameters and default token lifetimes.

        Args:
            secret_key: Key used to sign and verify tokens.
            algorithm: JWS algorithm used for signing; verification accepts
                only this algorithm.
            access_token_ttl: Access-token lifetime used when
                ``create_access_token`` gets no ``expires_delta``.
            refresh_token_ttl: Refresh-token lifetime used when
                ``create_refresh_token`` gets no ``expires_delta``.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_token_ttl = access_token_ttl
        self.refresh_token_ttl = refresh_token_ttl

    # ------------------------------------------------------------------
    # Token creation
    # ------------------------------------------------------------------

    def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed access token with the standard payload.

        Args:
            user_data: Must contain ``user_id`` and ``email``; ``full_name``
                and ``is_verified`` are copied into the token when present.
            expires_delta: Token lifetime. ``None`` uses
                ``self.access_token_ttl``; any explicit value (including
                ``timedelta(0)``) is used as given.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` or ``email`` is missing from ``user_data``.
        """
        user_id = self._subject(user_data["user_id"])
        claims: Dict[str, Any] = {
            "sub": user_id,      # standard JWT subject claim
            "user_id": user_id,  # explicit user ID, mirrors "sub"
            "email": user_data["email"],
            "type": "access",
        }
        for optional_field in ("full_name", "is_verified"):
            if optional_field in user_data:
                claims[optional_field] = user_data[optional_field]

        self._with_timestamps(claims, expires_delta, self.access_token_ttl)
        encoded_jwt = jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed refresh token with a minimal payload.

        Args:
            user_data: Must contain ``user_id``; other keys are ignored.
            expires_delta: Token lifetime. ``None`` uses
                ``self.refresh_token_ttl``; any explicit value (including
                ``timedelta(0)``) is used as given.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` is missing from ``user_data``.
        """
        user_id = self._subject(user_data["user_id"])
        claims: Dict[str, Any] = {
            "sub": user_id,
            "user_id": user_id,
            "type": "refresh",
        }

        self._with_timestamps(claims, expires_delta, self.refresh_token_ttl)
        encoded_jwt = jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
        logger.debug(f"Created refresh token for user {user_data['user_id']}")
        return encoded_jwt

    @staticmethod
    def _subject(user_id: Any) -> Any:
        """Return *user_id* in the string form stored in ``sub``/``user_id``.

        python-jose rejects a non-string ``sub`` during decoding, and
        ``_validate_payload`` requires a string ``user_id``. Non-string IDs
        (e.g. ``uuid.UUID`` or ``int``) are therefore converted with
        ``str()``. Without this, such tokens either fail to JSON-encode or
        fail this class's own verification. ``None`` is passed through
        unchanged, so it is still rejected at verification time instead of
        being minted as the literal string ``"None"``.
        """
        if user_id is None or isinstance(user_id, str):
            return user_id
        return str(user_id)

    @staticmethod
    def _with_timestamps(
        claims: Dict[str, Any],
        expires_delta: Optional[timedelta],
        default_ttl: timedelta,
    ) -> Dict[str, Any]:
        """Add ``exp`` and ``iat`` to *claims* in place and return it.

        Both claims derive from one ``now`` so that ``exp - iat`` equals the
        lifetime exactly. ``expires_delta`` is checked with ``is None``
        rather than truthiness, so an explicit ``timedelta(0)`` is not
        replaced by *default_ttl*.
        """
        issued_at = datetime.now(timezone.utc)
        lifetime = default_ttl if expires_delta is None else expires_delta
        claims["exp"] = issued_at + lifetime
        claims["iat"] = issued_at
        return claims

    # ------------------------------------------------------------------
    # Verification
    # ------------------------------------------------------------------

    def verify_token(self, token: str, expected_type: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Verify *token*'s signature and claims and return its payload.

        Args:
            token: Encoded JWT.
            expected_type: Optional ``"access"`` or ``"refresh"``. When
                given, a token whose ``type`` claim differs is rejected (for
                example, a refresh token presented where an access token is
                required). ``None`` accepts either type.

        Returns:
            The decoded payload, or ``None`` if the token is malformed, has a
            bad signature, is expired, fails ``_validate_payload``, or has
            the wrong type. Every failure is logged; nothing is raised.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],  # accept only the configured algorithm
                options={"verify_exp": True},
            )

            if not self._validate_payload(payload):
                logger.warning("Token payload validation failed")
                return None

            token_type = payload.get("type")
            if expected_type is not None and token_type != expected_type:
                logger.warning(f"Unexpected token type: expected {expected_type}, got {token_type}")
                return None

            # Second expiry guard. jose's verify_exp already checked "exp",
            # but at whole-second resolution.
            exp = payload.get("exp")
            if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
                logger.warning("Token has expired")
                return None

            logger.debug(f"Token verified successfully for user {payload.get('user_id')}")
            return payload

        except jwt.ExpiredSignatureError:
            logger.warning("Token has expired")
            return None
        except jwt.JWTClaimsError as e:
            logger.warning(f"Token claims validation failed: {e}")
            return None
        except JWTError as e:
            logger.warning(f"Token validation failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error during token verification: {e}")
            return None

    def decode_token_unsafe(self, token: str) -> Optional[Dict[str, Any]]:
        """Return *token*'s claims WITHOUT verifying signature or expiry.

        The returned claims are unauthenticated, so never use them for
        authorization decisions.

        This uses ``jose.jwt.get_unverified_claims``. python-jose's
        ``jwt.decode`` requires a ``key`` argument, so the PyJWT-style call
        ``jwt.decode(token, options={"verify_signature": False})`` always
        raises ``TypeError`` here.

        Returns:
            The claims mapping, or ``None`` if the token cannot be parsed.
        """
        try:
            return jwt.get_unverified_claims(token)
        except Exception as e:
            logger.error(f"Failed to decode token: {e}")
            return None

    def _validate_payload(self, payload: Dict[str, Any]) -> bool:
        """Check that *payload* has the structure this class produces.

        Requires ``sub``, ``user_id``, ``type``, ``exp`` and ``iat``. It also
        requires ``type`` to be one of ``_TOKEN_TYPES``, an ``email`` claim
        on access tokens, a non-empty string ``user_id``, and
        ``sub == user_id``. Each failure is logged.

        Returns:
            ``True`` if all checks pass, otherwise ``False``.
        """
        for required_field in ("sub", "user_id", "type", "exp", "iat"):
            if required_field not in payload:
                logger.warning(f"Missing required field in token: {required_field}")
                return False

        token_type = payload.get("type")
        if token_type not in self._TOKEN_TYPES:
            logger.warning(f"Invalid token type: {token_type}")
            return False

        # Access tokens must also identify the user by email.
        if token_type == "access" and "email" not in payload:
            logger.warning("Access token missing email field")
            return False

        # Only checks for a non-empty string; the UUID format is not verified.
        user_id = payload.get("user_id")
        if not user_id or not isinstance(user_id, str):
            logger.warning("Invalid user_id in token")
            return False

        if payload.get("sub") != user_id:
            logger.warning("Token subject does not match user_id")
            return False

        return True

    # ------------------------------------------------------------------
    # Unverified inspection helpers
    # ------------------------------------------------------------------

    def extract_user_id(self, token: str) -> Optional[str]:
        """Return the ``user_id`` claim of *token* without verifying it.

        The signature is not checked, so the value is unauthenticated and
        must not be trusted for authorization.

        Returns:
            The ``user_id`` claim, or ``None`` if the token cannot be parsed
            or has no such claim.
        """
        payload = self.decode_token_unsafe(token)
        return payload.get("user_id") if payload else None

    @staticmethod
    def _claims_expired(payload: Optional[Dict[str, Any]]) -> bool:
        """Return ``False`` only if *payload* has an ``exp`` still in the future.

        A missing payload or a missing ``exp`` counts as expired. Raises
        (e.g. ``TypeError``) if ``exp`` is not a numeric timestamp.
        """
        if not payload or "exp" not in payload:
            return True
        expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        return expires_at < datetime.now(timezone.utc)

    def is_token_expired(self, token: str) -> bool:
        """Report whether *token* is expired, without verifying its signature.

        Returns:
            ``True`` if ``exp`` is in the past, missing, or unreadable, or if
            the token cannot be parsed at all; ``False`` otherwise.
        """
        try:
            return self._claims_expired(self.decode_token_unsafe(token))
        except Exception as e:
            logger.warning(f"Failed to check token expiration: {e}")
            return True  # assume expired if we can't check

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """Collect debugging information about *token*.

        Returns:
            A dict with the unverified ``user_id``, ``email``, ``type``,
            ``exp`` and ``iat`` claims, an ``expired`` flag, and ``valid``
            (whether ``verify_token`` accepts the token). If the token cannot
            be parsed, the claim fields stay ``None``, ``expired`` stays
            ``True`` and ``valid`` is ``False``.
        """
        info: Dict[str, Any] = {
            "valid": False,
            "expired": True,
            "user_id": None,
            "email": None,
            "type": None,
            "exp": None,
            "iat": None,
        }

        try:
            payload = self.decode_token_unsafe(token)
            if payload:
                info.update({claim: payload.get(claim) for claim in ("user_id", "email", "type", "exp", "iat")})
                info["expired"] = self.is_token_expired(token)

            # Full verification (signature, expiry, payload structure).
            info["valid"] = self.verify_token(token) is not None
        except Exception as e:
            logger.warning(f"Failed to get token info: {e}")

        return info