# shared/auth/jwt_handler.py
"""
JWT handler for issuing and validating access and refresh tokens.

Tokens carry only user identity (``sub`` / ``user_id`` / ``email``) plus a
``type`` claim distinguishing access from refresh tokens. Tenant context is
determined per request and is deliberately not stored in the token.
"""
from jose import jwt, JWTError
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
import structlog

logger = structlog.get_logger()


class JWTHandler:
    """Create and verify signed JWT access/refresh tokens."""

    # Lifetimes used when a caller does not pass ``expires_delta``.
    ACCESS_TOKEN_TTL = timedelta(minutes=30)
    REFRESH_TOKEN_TTL = timedelta(days=7)

    # Token types this handler issues and accepts.
    TOKEN_TYPES = ("access", "refresh")

    # Claims a decoded token must contain to be accepted.
    REQUIRED_CLAIMS = ("user_id", "email", "exp", "type")

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Args:
            secret_key: Key used both to sign new tokens and to verify them.
            algorithm: JWS algorithm name handed to python-jose for encoding;
                it is also the only algorithm accepted when decoding.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    def _create_token(
        self,
        user_data: Dict[str, Any],
        token_type: str,
        expires_delta: Optional[timedelta],
        default_ttl: timedelta,
    ) -> str:
        """Build, timestamp and sign a token of ``token_type`` for ``user_data``.

        Args:
            user_data: Mapping that must contain ``"user_id"`` and ``"email"``.
            token_type: Value stored in the ``type`` claim.
            expires_delta: Explicit lifetime; ``None`` selects ``default_ttl``.
            default_ttl: Lifetime used when ``expires_delta`` is ``None``.

        Returns:
            The encoded JWT string.

        Raises:
            KeyError: If ``user_data`` lacks ``"user_id"`` or ``"email"``.
        """
        # One clock read so ``iat`` and ``exp`` are exactly ``ttl`` apart.
        now = datetime.now(timezone.utc)
        # Compare against None: ``timedelta(0)`` is falsy, and a plain
        # truthiness test would silently replace it with the default lifetime.
        ttl = default_ttl if expires_delta is None else expires_delta
        claims = {
            # RFC 7519 requires ``sub`` to be a string, and python-jose rejects
            # non-string subjects on decode. Coerce it (e.g. for integer ids).
            "sub": str(user_data["user_id"]),
            "user_id": user_data["user_id"],
            "email": user_data["email"],
            "type": token_type,
            "exp": now + ttl,
            "iat": now,
        }
        return jwt.encode(claims, self.secret_key, algorithm=self.algorithm)

    def create_access_token(self, user_data: Dict[str, Any],
                            expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token WITHOUT tenant_id.

        Tenant context is determined per request, not stored in the token.

        Args:
            user_data: Mapping that must contain ``"user_id"`` and ``"email"``.
            expires_delta: Token lifetime; defaults to ``ACCESS_TOKEN_TTL``
                (30 minutes) when ``None``.

        Returns:
            The encoded access token.

        Raises:
            KeyError: If ``user_data`` lacks ``"user_id"`` or ``"email"``.
        """
        encoded_jwt = self._create_token(
            user_data, "access", expires_delta, self.ACCESS_TOKEN_TTL
        )
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any],
                             expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT refresh token.

        Args:
            user_data: Mapping that must contain ``"user_id"`` and ``"email"``.
            expires_delta: Token lifetime; defaults to ``REFRESH_TOKEN_TTL``
                (7 days) when ``None``.

        Returns:
            The encoded refresh token.

        Raises:
            KeyError: If ``user_data`` lacks ``"user_id"`` or ``"email"``.
        """
        return self._create_token(
            user_data, "refresh", expires_delta, self.REFRESH_TOKEN_TTL
        )

    def verify_token(self, token: str, *,
                     expected_type: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Verify and decode a JWT, applying the handler's claim checks.

        Args:
            token: The encoded JWT.
            expected_type: If given (``"access"`` or ``"refresh"``), tokens whose
                ``type`` claim differs are rejected. ``None`` accepts either
                type, which matches the historical behaviour.

        Returns:
            The decoded claims dict, or ``None`` if the token is invalid,
            expired, incomplete, or of the wrong type. Failures are logged
            rather than raised.

        Raises:
            ValueError: If ``expected_type`` is not one of ``TOKEN_TYPES``.
        """
        if expected_type is not None and expected_type not in self.TOKEN_TYPES:
            raise ValueError(f"Unknown expected_type: {expected_type!r}")

        # The whole validation stays inside ``try``: post-decode checks such as
        # ``fromtimestamp`` can raise on malformed claim values, and the
        # contract is to return None rather than propagate.
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])

            # Report the claims that are actually absent, not the whole list.
            missing = [claim for claim in self.REQUIRED_CLAIMS if claim not in payload]
            if missing:
                logger.warning(f"Token missing required fields: {missing}")
                return None

            token_type = payload["type"]
            if token_type not in self.TOKEN_TYPES:
                logger.warning(f"Invalid token type: {token_type}")
                return None
            if expected_type is not None and token_type != expected_type:
                logger.warning(
                    f"Unexpected token type: {token_type} (expected {expected_type})"
                )
                return None

            # python-jose already rejects expired tokens; this is a defensive
            # second check against the current UTC time.
            exp = payload["exp"]
            if datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
                logger.warning("Token has expired")
                return None

            return payload

        except JWTError as e:
            logger.warning(f"JWT validation failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error validating token: {e}")
            return None