# shared/auth/jwt_handler.py
"""
Enhanced JWT Handler with proper token structure and validation
FIXED VERSION - Consistent token format between all services
"""

from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import structlog
from jose import jwt, JWTError
from jose.exceptions import ExpiredSignatureError, JWTClaimsError

logger = structlog.get_logger()


class JWTHandler:
    """
    Create, verify and inspect JWTs that share one payload format.

    Every token carries ``sub``, ``user_id``, ``type``, ``exp`` and ``iat``.
    Access tokens additionally carry ``email`` (and ``full_name`` /
    ``is_verified`` when supplied); refresh tokens carry nothing else.
    """

    # Lifetimes used when the caller does not pass ``expires_delta``.
    DEFAULT_ACCESS_TTL = timedelta(minutes=30)
    DEFAULT_REFRESH_TTL = timedelta(days=7)

    # Tuples (not sets) so membership tests never raise on unhashable values
    # that may appear in a hostile payload.
    _REQUIRED_FIELDS = ("sub", "user_id", "type", "exp", "iat")
    _TOKEN_TYPES = ("access", "refresh")
    _OPTIONAL_ACCESS_FIELDS = ("full_name", "is_verified")

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: JWT signing algorithm name passed to python-jose.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    def create_access_token(
        self,
        user_data: Dict[str, Any],
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a signed access token.

        Args:
            user_data: Must contain ``user_id`` and ``email``; ``full_name``
                and ``is_verified`` are copied into the token when present.
            expires_delta: Token lifetime. ``None`` means 30 minutes.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` or ``email`` is missing from user_data.
        """
        subject = self._subject(user_data)
        claims = {
            "sub": subject,  # Standard JWT subject claim
            "user_id": subject,  # Explicit user ID
            "email": user_data["email"],
            "type": "access",
        }
        for name in self._OPTIONAL_ACCESS_FIELDS:
            if name in user_data:
                claims[name] = user_data[name]

        encoded_jwt = self._encode(claims, expires_delta, self.DEFAULT_ACCESS_TTL)
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(
        self,
        user_data: Dict[str, Any],
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a signed refresh token with a minimal payload.

        Args:
            user_data: Must contain ``user_id``.
            expires_delta: Token lifetime. ``None`` means 7 days.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_id`` is missing from user_data.
        """
        subject = self._subject(user_data)
        claims = {
            "sub": subject,
            "user_id": subject,
            "type": "refresh",
        }

        encoded_jwt = self._encode(claims, expires_delta, self.DEFAULT_REFRESH_TTL)
        logger.debug(f"Created refresh token for user {user_data['user_id']}")
        return encoded_jwt

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify the signature, expiry and payload structure of a token.

        Args:
            token: Encoded JWT string.

        Returns:
            The decoded payload, or ``None`` if the token is expired,
            malformed, wrongly signed, or fails ``_validate_payload``.
            This method never raises; every failure is logged.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": True},  # Verify expiration
            )

            if not self._validate_payload(payload):
                logger.warning("Token payload validation failed")
                return None

            # Defence in depth: the library has already checked ``exp``;
            # this re-check covers a token that expires between the two.
            exp = payload.get("exp")
            if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
                logger.warning("Token has expired")
                return None

            logger.debug(f"Token verified successfully for user {payload.get('user_id')}")
            return payload

        # Order matters: the first two are subclasses of JWTError.
        except ExpiredSignatureError:
            logger.warning("Token has expired")
            return None
        except JWTClaimsError as e:
            logger.warning(f"Token claims validation failed: {e}")
            return None
        except JWTError as e:
            logger.warning(f"Token validation failed: {e}")
            return None
        except Exception as e:
            # Boundary catch-all so callers only ever see a payload or None.
            logger.error(f"Unexpected error during token verification: {e}")
            return None

    def decode_token_unsafe(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Decode a token WITHOUT verifying signature or expiry (debugging only).

        The returned claims are attacker-controlled; never use them for
        authentication or authorization decisions.

        Args:
            token: Encoded JWT string.

        Returns:
            The unverified claims, or ``None`` if the token cannot be parsed.
        """
        try:
            # python-jose's jwt.decode() requires a key argument, so the
            # previous keyless call raised TypeError (swallowed below) and
            # this method always returned None. get_unverified_claims() is
            # the library's API for reading claims without verification.
            return jwt.get_unverified_claims(token)
        except Exception as e:
            logger.error(f"Failed to decode token: {e}")
            return None

    def _validate_payload(self, payload: Dict[str, Any]) -> bool:
        """
        Check that a decoded payload has the structure this handler issues.

        Args:
            payload: Decoded JWT claims.

        Returns:
            ``True`` when all required fields are present, ``type`` is
            ``access`` or ``refresh``, access tokens carry ``email``,
            ``user_id`` is a non-empty string and ``sub`` equals ``user_id``;
            ``False`` otherwise (the reason is logged).
        """
        for field in self._REQUIRED_FIELDS:
            if field not in payload:
                logger.warning(f"Missing required field in token: {field}")
                return False

        token_type = payload.get("type")
        if token_type not in self._TOKEN_TYPES:
            logger.warning(f"Invalid token type: {token_type}")
            return False

        if token_type == "access" and "email" not in payload:
            logger.warning("Access token missing email field")
            return False

        # Only the type is checked here; the UUID format is not enforced.
        user_id = payload.get("user_id")
        if not user_id or not isinstance(user_id, str):
            logger.warning("Invalid user_id in token")
            return False

        if payload.get("sub") != user_id:
            logger.warning("Token subject does not match user_id")
            return False

        return True

    def extract_user_id(self, token: str) -> Optional[str]:
        """
        Read ``user_id`` from a token without verifying it.

        The value is unverified and must not be trusted for access control.

        Args:
            token: Encoded JWT string.

        Returns:
            The ``user_id`` claim, or ``None`` if the token cannot be parsed
            or has no such claim.
        """
        payload = self.decode_token_unsafe(token)
        return payload.get("user_id") if payload else None

    def is_token_expired(self, token: str) -> bool:
        """
        Check whether a token's ``exp`` is in the past, without verification.

        Args:
            token: Encoded JWT string.

        Returns:
            ``True`` if the token is expired, or if it cannot be parsed or
            has no usable ``exp`` claim (assume expired when unsure).
        """
        return self._claims_expired(self.decode_token_unsafe(token))

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """
        Collect token details for debugging.

        Args:
            token: Encoded JWT string.

        Returns:
            A dict with keys ``valid``, ``expired``, ``user_id``, ``email``,
            ``type``, ``exp`` and ``iat``. Claim values come from the
            unverified payload; ``valid`` reflects full verification.
        """
        info: Dict[str, Any] = {
            "valid": False,
            "expired": True,
            "user_id": None,
            "email": None,
            "type": None,
            "exp": None,
            "iat": None,
        }

        payload = self.decode_token_unsafe(token)
        if payload:
            info.update(
                {
                    "user_id": payload.get("user_id"),
                    "email": payload.get("email"),
                    "type": payload.get("type"),
                    "exp": payload.get("exp"),
                    "iat": payload.get("iat"),
                    # Reuse the decoded payload instead of decoding again.
                    "expired": self._claims_expired(payload),
                }
            )

        # verify_token never raises, so no guard is needed here.
        info["valid"] = self.verify_token(token) is not None
        return info

    @staticmethod
    def _subject(user_data: Dict[str, Any]) -> Any:
        """Return ``user_id`` as a string so issued tokens pass validation."""
        user_id = user_data["user_id"]
        # Verification requires string ids; coerce UUID/int ids on the way
        # in. None is passed through untouched rather than becoming "None".
        return user_id if user_id is None else str(user_id)

    def _encode(
        self,
        claims: Dict[str, Any],
        expires_delta: Optional[timedelta],
        default_ttl: timedelta,
    ) -> str:
        """Add ``exp``/``iat`` to the claims and sign them."""
        # ``is None`` rather than truthiness: timedelta(0) is falsy but is a
        # legitimate explicit lifetime, not a request for the default.
        ttl = default_ttl if expires_delta is None else expires_delta
        # One clock reading so ``exp - iat`` equals the lifetime exactly.
        now = datetime.now(timezone.utc)
        payload = {**claims, "exp": now + ttl, "iat": now}
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    @staticmethod
    def _claims_expired(payload: Optional[Dict[str, Any]]) -> bool:
        """Return True if the payload's ``exp`` is past, missing or unusable."""
        if not payload or "exp" not in payload:
            return True  # Assume expired if we can't check
        try:
            expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError) as e:
            logger.warning(f"Failed to check token expiration: {e}")
            return True
        return expires_at < datetime.now(timezone.utc)