REFACTOR API gateway fix 8

This commit is contained in:
Urtzi Alfaro
2025-07-26 23:29:57 +02:00
parent 1291d05183
commit 97ae58fb06
8 changed files with 997 additions and 375 deletions

View File

@@ -18,16 +18,50 @@ class JWTHandler:
self.secret_key = secret_key
self.algorithm = algorithm
def create_access_token_from_payload(self, payload: Dict[str, Any]) -> str:
"""
Create JWT ACCESS token from complete payload
✅ FIXED: Only creates access tokens with access token structure
"""
try:
# Ensure this is marked as an access token
payload["type"] = "access"
encoded_jwt = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
logger.debug(f"Created access token with payload keys: {list(payload.keys())}")
return encoded_jwt
except Exception as e:
logger.error(f"Access token creation failed: {e}")
raise ValueError(f"Failed to encode access token: {str(e)}")
def create_refresh_token_from_payload(self, payload: Dict[str, Any]) -> str:
"""
Create JWT REFRESH token from complete payload
✅ FIXED: Only creates refresh tokens with refresh token structure
"""
try:
# Ensure this is marked as a refresh token
payload["type"] = "refresh"
encoded_jwt = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
logger.debug(f"Created refresh token with payload keys: {list(payload.keys())}")
return encoded_jwt
except Exception as e:
logger.error(f"Refresh token creation failed: {e}")
raise ValueError(f"Failed to encode refresh token: {str(e)}")
def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Create JWT access token with STANDARD structure
FIXED: Consistent payload format across all services
Create JWT access token with STANDARD structure (legacy method)
FIXED: Consistent payload format for access tokens
"""
to_encode = {
"sub": user_data["user_id"], # Standard JWT subject claim
"user_id": user_data["user_id"], # Explicit user ID
"email": user_data["email"], # User email
"type": "access" # Token type
"sub": user_data["user_id"],
"user_id": user_data["user_id"],
"email": user_data["email"],
"type": "access"
}
# Add optional fields if present
@@ -35,6 +69,8 @@ class JWTHandler:
to_encode["full_name"] = user_data["full_name"]
if "is_verified" in user_data:
to_encode["is_verified"] = user_data["is_verified"]
if "is_active" in user_data:
to_encode["is_active"] = user_data["is_active"]
# Set expiration
if expires_delta:
@@ -44,7 +80,8 @@ class JWTHandler:
to_encode.update({
"exp": expire,
"iat": datetime.now(timezone.utc)
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth"
})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
@@ -53,8 +90,8 @@ class JWTHandler:
def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
"""
Create JWT refresh token with MINIMAL payload
FIXED: Consistent refresh token structure
Create JWT refresh token with MINIMAL payload (legacy method)
FIXED: Consistent refresh token structure, different from access
"""
to_encode = {
"sub": user_data["user_id"],
@@ -62,14 +99,27 @@ class JWTHandler:
"type": "refresh"
}
# Add unique identifier to prevent duplicates
if "jti" in user_data:
to_encode["jti"] = user_data["jti"]
else:
import uuid
to_encode["jti"] = str(uuid.uuid4())
# Include email only if available (optional for refresh tokens)
if "email" in user_data and user_data["email"]:
to_encode["email"] = user_data["email"]
# Set expiration
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=7)
expire = datetime.now(timezone.utc) + timedelta(days=30)
to_encode.update({
"exp": expire,
"iat": datetime.now(timezone.utc)
"iat": datetime.now(timezone.utc),
"iss": "bakery-auth"
})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
@@ -78,95 +128,64 @@ class JWTHandler:
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
"""
Verify and decode JWT token with comprehensive validation
FIXED: Better error handling and validation
Verify and decode JWT token
"""
try:
# Decode token
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
options={"verify_exp": True} # Verify expiration
)
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
# Validate required fields
if not self._validate_payload(payload):
logger.warning("Token payload validation failed")
return None
# Check if token is expired
exp_timestamp = payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
if datetime.now(timezone.utc) > exp_datetime:
logger.debug("Token is expired")
return None
# Check if token is expired (additional check)
exp = payload.get("exp")
if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
logger.warning("Token has expired")
return None
logger.debug(f"Token verified successfully for user {payload.get('user_id')}")
logger.debug(f"Token verified successfully, type: {payload.get('type', 'unknown')}")
return payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
return None
except jwt.JWTClaimsError as e:
logger.warning(f"Token claims validation failed: {e}")
return None
except jwt.JWTError as e:
logger.warning(f"Token validation failed: {e}")
except JWTError as e:
logger.warning(f"JWT verification failed: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error during token verification: {e}")
logger.error(f"Token verification error: {e}")
return None
def decode_token_unsafe(self, token: str) -> Optional[Dict[str, Any]]:
def decode_token_no_verify(self, token: str) -> Dict[str, Any]:
"""
Decode JWT token without verification (for debugging only)
Decode JWT token without verification (for inspection purposes)
"""
try:
return jwt.decode(
token,
options={"verify_signature": False, "verify_exp": False}
)
# Decode without verification
payload = jwt.decode(token, options={"verify_signature": False})
return payload
except Exception as e:
logger.error(f"Failed to decode token: {e}")
logger.error(f"Token decoding failed: {e}")
raise ValueError("Invalid token format")
def get_token_type(self, token: str) -> Optional[str]:
"""
Get the type of token (access or refresh) without full verification
"""
try:
payload = self.decode_token_no_verify(token)
return payload.get("type")
except Exception:
return None
def _validate_payload(self, payload: Dict[str, Any]) -> bool:
def is_token_expired(self, token: str) -> bool:
"""
Validate JWT payload structure
FIXED: Comprehensive validation for required fields
Check if token is expired without full verification
"""
# Check required fields for all tokens
required_base_fields = ["sub", "user_id", "type", "exp", "iat"]
for field in required_base_fields:
if field not in payload:
logger.warning(f"Missing required field in token: {field}")
return False
# Validate token type
token_type = payload.get("type")
if token_type not in ["access", "refresh"]:
logger.warning(f"Invalid token type: {token_type}")
return False
# Additional validation for access tokens
if token_type == "access":
if "email" not in payload:
logger.warning("Access token missing email field")
return False
# Validate user_id format (should be UUID)
user_id = payload.get("user_id")
if not user_id or not isinstance(user_id, str):
logger.warning("Invalid user_id in token")
return False
# Validate subject matches user_id
if payload.get("sub") != user_id:
logger.warning("Token subject does not match user_id")
return False
return True
try:
payload = self.decode_token_no_verify(token)
exp_timestamp = payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
return datetime.now(timezone.utc) > exp_datetime
return True
except Exception:
return True
def extract_user_id(self, token: str) -> Optional[str]:
"""
@@ -182,20 +201,6 @@ class JWTHandler:
return None
def is_token_expired(self, token: str) -> bool:
"""
Check if token is expired without full verification
"""
try:
payload = self.decode_token_unsafe(token)
if payload and "exp" in payload:
exp = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
return exp < datetime.now(timezone.utc)
except Exception as e:
logger.warning(f"Failed to check token expiration: {e}")
return True # Assume expired if we can't check
def get_token_info(self, token: str) -> Dict[str, Any]:
"""
Get comprehensive token information for debugging