# shared/auth/jwt_handler.py
"""
Enhanced JWT handler with a consistent token structure and validation.

All services share this module so access and refresh tokens have the same
format everywhere. Every token carries a ``type`` claim (``"access"`` or
``"refresh"``) so the two kinds can be told apart.
"""

import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import structlog
from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError

logger = structlog.get_logger()


class JWTHandler:
    """Create, verify and inspect JWTs signed with a single shared secret."""

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Args:
            secret_key: Key used both to sign and to verify tokens.
            algorithm: JWS algorithm name passed to python-jose (default HS256).
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    # ------------------------------------------------------------------
    # Token creation
    # ------------------------------------------------------------------

    def _encode_typed(self, payload: Dict[str, Any], token_type: str) -> str:
        """Encode a copy of *payload* with its ``type`` claim forced to *token_type*.

        The caller's dict is never modified.

        Raises:
            ValueError: if the payload cannot be copied or encoded.
        """
        try:
            # Work on a copy: writing "type" into the caller's dict would leak
            # into any later use of that dict by the caller.
            claims = {**payload, "type": token_type}
            encoded_jwt = jwt.encode(claims, self.secret_key, algorithm=self.algorithm)
        except Exception as e:
            logger.error(f"{token_type.capitalize()} token creation failed: {e}")
            raise ValueError(f"Failed to encode {token_type} token: {str(e)}") from e
        logger.debug(f"Created {token_type} token with payload keys: {list(claims.keys())}")
        return encoded_jwt

    def create_access_token_from_payload(self, payload: Dict[str, Any]) -> str:
        """
        Create a JWT ACCESS token from a complete, caller-built payload.

        Only the ``type`` claim is set (to ``"access"``). ``exp``, ``iat``,
        ``iss`` and so on are NOT added here; the caller must supply them.
        The caller's dict is left unmodified.

        Raises:
            ValueError: if encoding fails.
        """
        return self._encode_typed(payload, "access")

    def create_refresh_token_from_payload(self, payload: Dict[str, Any]) -> str:
        """
        Create a JWT REFRESH token from a complete, caller-built payload.

        Only the ``type`` claim is set (to ``"refresh"``). All other claims,
        including expiry, must be supplied by the caller. The caller's dict is
        left unmodified.

        Raises:
            ValueError: if encoding fails.
        """
        return self._encode_typed(payload, "refresh")

    def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT access token with the standard structure (legacy method).

        Args:
            user_data: Must contain ``user_id`` and ``email``. Optional keys
                ``full_name``, ``is_verified`` and ``is_active`` are copied
                when present.
            expires_delta: Lifetime of the token. A falsy value (``None`` or a
                zero timedelta) means the default of 30 minutes.

        Returns:
            The encoded token, carrying ``sub``/``user_id``/``email``/``type``,
            the optional fields, and ``exp``/``iat``/``iss="bakery-auth"``.

        Raises:
            KeyError: if ``user_id`` or ``email`` is missing.
        """
        # NOTE(review): python-jose rejects a non-string "sub" on decode;
        # this assumes user_id is already a str. Confirm against callers.
        to_encode: Dict[str, Any] = {
            "sub": user_data["user_id"],
            "user_id": user_data["user_id"],
            "email": user_data["email"],
            "type": "access",
        }

        # Copy optional profile fields only when the caller provided them.
        for field in ("full_name", "is_verified", "is_active"):
            if field in user_data:
                to_encode[field] = user_data[field]

        # One timestamp for both iat and exp so they are mutually consistent.
        now = datetime.now(timezone.utc)
        expire = now + (expires_delta or timedelta(minutes=30))
        to_encode.update({
            "exp": expire,
            "iat": now,
            "iss": "bakery-auth",
        })

        encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a JWT refresh token with a minimal payload (legacy method).

        Args:
            user_data: Must contain ``user_id``. ``jti`` is used if present,
                otherwise a random UUID4 is generated. ``email`` is included
                only when present and truthy.
            expires_delta: Lifetime of the token. A falsy value means the
                default of 30 days.

        Returns:
            The encoded token with ``sub``/``user_id``/``type="refresh"``/``jti``,
            optional ``email``, and ``exp``/``iat``/``iss="bakery-auth"``.

        Raises:
            KeyError: if ``user_id`` is missing.
        """
        to_encode: Dict[str, Any] = {
            "sub": user_data["user_id"],
            "user_id": user_data["user_id"],
            "type": "refresh",
        }

        # Unique token id so two refresh tokens for the same user never collide.
        to_encode["jti"] = user_data["jti"] if "jti" in user_data else str(uuid.uuid4())

        # Email is optional for refresh tokens.
        if user_data.get("email"):
            to_encode["email"] = user_data["email"]

        now = datetime.now(timezone.utc)
        expire = now + (expires_delta or timedelta(days=30))
        to_encode.update({
            "exp": expire,
            "iat": now,
            "iss": "bakery-auth",
        })

        encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
        logger.debug(f"Created refresh token for user {user_data['user_id']}")
        return encoded_jwt

    # ------------------------------------------------------------------
    # Verification and inspection
    # ------------------------------------------------------------------

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Verify the signature and standard claims of *token* and return its payload.

        Expiry is enforced by ``jwt.decode`` itself, which raises
        ``ExpiredSignatureError``. No separate expiry check is needed.

        Returns:
            The decoded payload, or ``None`` if the token is expired, invalid,
            or cannot be decoded. This method never raises.
        """
        # NOTE(review): no issuer/audience is passed, so "iss" is not checked here.
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except ExpiredSignatureError:
            # Routine condition; keep it at debug level rather than warning.
            logger.debug("Token is expired")
            return None
        except JWTError as e:
            logger.warning(f"JWT verification failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Token verification error: {e}")
            return None

        logger.debug(f"Token verified successfully, type: {payload.get('type', 'unknown')}")
        return payload

    def decode_token_no_verify(self, token: str) -> Dict[str, Any]:
        """
        Decode *token*'s claims WITHOUT verifying signature or claims.

        This is for inspection only; never trust the result for authorization.
        ``jwt.get_unverified_claims`` is used because python-jose's
        ``jwt.decode`` still validates exp/nbf/aud even when
        ``verify_signature`` is disabled. That would make expired tokens
        impossible to inspect.

        Raises:
            ValueError: if the token is malformed or its claims are not a JSON object.
        """
        try:
            return jwt.get_unverified_claims(token)
        except Exception as e:
            logger.error(f"Token decoding failed: {e}")
            raise ValueError("Invalid token format") from e

    def get_token_type(self, token: str) -> Optional[str]:
        """Return the unverified ``type`` claim, or ``None`` if absent or undecodable."""
        try:
            payload = self.decode_token_no_verify(token)
        except ValueError:
            return None
        return payload.get("type")

    @staticmethod
    def _is_payload_expired(payload: Dict[str, Any]) -> bool:
        """Return True if *payload*'s ``exp`` is in the past, missing/zero, or unusable."""
        exp_timestamp = payload.get("exp")
        if not exp_timestamp:
            # A missing (or zero) expiry is deliberately reported as expired.
            return True
        try:
            exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            # Non-numeric or out-of-range exp: treat as expired.
            return True
        return datetime.now(timezone.utc) > exp_datetime

    def is_token_expired(self, token: str) -> bool:
        """
        Check the token's ``exp`` claim without verifying the signature.

        Returns:
            True if expired, if ``exp`` is missing or invalid, or if the token
            cannot be decoded. False otherwise.
        """
        try:
            payload = self.decode_token_no_verify(token)
        except ValueError:
            return True
        return self._is_payload_expired(payload)

    def extract_user_id(self, token: str) -> Optional[str]:
        """
        Extract the ``user_id`` claim without verification.

        This is useful for quick user identification (e.g. logging). It must
        not be used for authorization.

        Returns:
            The ``user_id`` claim, or ``None`` if absent or the token is undecodable.
        """
        try:
            payload = self.decode_token_no_verify(token)
        except ValueError as e:
            logger.warning(f"Failed to extract user ID from token: {e}")
            return None
        return payload.get("user_id")

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """
        Collect debugging information about *token*.

        Returns:
            A dict with keys ``valid`` (full verification passed), ``expired``,
            ``user_id``, ``email``, ``type``, ``exp`` and ``iat``. Unverified
            claim values are read via ``decode_token_no_verify``. If the token
            cannot be decoded at all, the defaults are returned
            (``valid=False``, ``expired=True``, everything else ``None``).
        """
        info: Dict[str, Any] = {
            "valid": False,
            "expired": True,
            "user_id": None,
            "email": None,
            "type": None,
            "exp": None,
            "iat": None,
        }

        try:
            payload = self.decode_token_no_verify(token)
        except ValueError as e:
            logger.warning(f"Failed to get token info: {e}")
            return info

        if payload:
            info.update({
                "user_id": payload.get("user_id"),
                "email": payload.get("email"),
                "type": payload.get("type"),
                "exp": payload.get("exp"),
                "iat": payload.get("iat"),
                # Reuse the already-decoded payload instead of decoding again.
                "expired": self._is_payload_expired(payload),
            })

        # verify_token never raises; it returns None on any failure.
        info["valid"] = self.verify_token(token) is not None
        return info