# shared/auth/jwt_handler.py - IMPROVED VERSION

"""
Enhanced JWT handler with a consistent token structure: access and refresh
tokens carry the same claims (sub, user_id, email, type, exp, iat) and differ
only in their "type" claim and default lifetime.
"""
|
|
|
|
|
|
2025-07-19 17:49:03 +02:00
|
|
|
from jose import jwt, JWTError
|
2025-07-18 07:46:56 +02:00
|
|
|
from datetime import datetime, timedelta, timezone
|
2025-07-17 13:09:24 +02:00
|
|
|
from typing import Optional, Dict, Any
|
2025-07-19 17:49:03 +02:00
|
|
|
import structlog
|
2025-07-17 13:09:24 +02:00
|
|
|
|
2025-07-19 17:49:03 +02:00
|
|
|
# Module-level structlog logger used by JWTHandler to report token creation
# and verification failures.
logger = structlog.get_logger()
|
2025-07-17 13:09:24 +02:00
|
|
|
|
|
|
|
|
class JWTHandler:
    """Create and verify signed JWT access and refresh tokens.

    Both token kinds carry the same claims (``sub``, ``user_id``, ``email``,
    ``type``, ``exp``, ``iat``) and differ only in the ``type`` claim and in
    their default lifetime.
    """

    ACCESS_TOKEN_TYPE = "access"
    REFRESH_TOKEN_TYPE = "refresh"
    DEFAULT_ACCESS_TOKEN_LIFETIME = timedelta(minutes=30)
    DEFAULT_REFRESH_TOKEN_LIFETIME = timedelta(days=7)

    # Claims every token must carry to be accepted by verify_token().
    _REQUIRED_CLAIMS = ("user_id", "email", "exp", "type")
    # A tuple (not a set) so a membership test on an unhashable, token-supplied
    # "type" value compares by equality instead of raising TypeError.
    _VALID_TOKEN_TYPES = (ACCESS_TOKEN_TYPE, REFRESH_TOKEN_TYPE)

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """Store the signing configuration.

        Args:
            secret_key: Key used both to sign new tokens and to verify
                incoming ones.
            algorithm: JWS algorithm used for signing; verification accepts
                only this algorithm.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create JWT access token WITHOUT tenant_id.

        Tenant context is determined per request, not stored in token.

        Args:
            user_data: Mapping that must contain "user_id" and "email".
            expires_delta: Token lifetime. When None, defaults to
                DEFAULT_ACCESS_TOKEN_LIFETIME (30 minutes); any explicit value,
                including ``timedelta(0)``, is honoured.

        Returns:
            The signed, encoded JWT.

        Raises:
            KeyError: If "user_id" or "email" is missing from ``user_data``.
        """
        encoded_jwt = self._create_token(
            user_data,
            self.ACCESS_TOKEN_TYPE,
            expires_delta,
            self.DEFAULT_ACCESS_TOKEN_LIFETIME,
        )
        logger.debug(f"Created access token for user {user_data['email']}")
        return encoded_jwt

    def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """Create JWT refresh token.

        Args:
            user_data: Mapping that must contain "user_id" and "email".
            expires_delta: Token lifetime. When None, defaults to
                DEFAULT_REFRESH_TOKEN_LIFETIME (7 days); any explicit value,
                including ``timedelta(0)``, is honoured.

        Returns:
            The signed, encoded JWT.

        Raises:
            KeyError: If "user_id" or "email" is missing from ``user_data``.
        """
        return self._create_token(
            user_data,
            self.REFRESH_TOKEN_TYPE,
            expires_delta,
            self.DEFAULT_REFRESH_TOKEN_LIFETIME,
        )

    def _create_token(
        self,
        user_data: Dict[str, Any],
        token_type: str,
        expires_delta: Optional[timedelta],
        default_lifetime: timedelta,
    ) -> str:
        """Build the shared claim set for ``token_type`` and sign it."""
        user_id = user_data["user_id"]
        email = user_data["email"]

        # `is None` rather than truthiness: timedelta(0) is falsy but is a
        # legitimate explicit lifetime and must not fall back to the default.
        lifetime = default_lifetime if expires_delta is None else expires_delta
        # Take "now" once so iat and exp are derived from the same instant.
        now = datetime.now(timezone.utc)

        to_encode = {
            # RFC 7519 defines "sub" as a string, and python-jose rejects a
            # non-string subject on decode; coerce it so non-string (e.g.
            # integer) user ids still produce verifiable tokens. "user_id"
            # keeps the caller's original value and type.
            "sub": str(user_id),
            "user_id": user_id,
            "email": email,
            "type": token_type,
            "exp": now + lifetime,
            "iat": now,
        }
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str, expected_type: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Verify and decode JWT token with comprehensive validation.

        Args:
            token: The encoded JWT.
            expected_type: Optional required value of the "type" claim
                ("access" or "refresh"). Pass "access" when authenticating
                requests so that a long-lived refresh token cannot stand in for
                an access token. None (the default) accepts either type.

        Returns:
            The decoded claims, or None if the signature or claims are invalid,
            the token has expired, a required claim is missing, or the token
            type is not acceptable. Failures are logged; this method does not
            raise.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])

            missing = [claim for claim in self._REQUIRED_CLAIMS if claim not in payload]
            if missing:
                # Report the claims that are actually absent, not the full list.
                logger.warning(f"Token missing required fields: {missing}")
                return None

            token_type = payload["type"]
            if token_type not in self._VALID_TOKEN_TYPES:
                logger.warning(f"Invalid token type: {token_type}")
                return None
            if expected_type is not None and token_type != expected_type:
                logger.warning(f"Unexpected token type: expected {expected_type}, got {token_type}")
                return None

            # python-jose already rejects expired tokens during decode; this is
            # a defence-in-depth re-check against the current UTC time.
            exp = payload["exp"]
            if exp and datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc):
                logger.warning("Token has expired")
                return None

            return payload

        except JWTError as e:
            logger.warning(f"JWT validation failed: {e}")
            return None
        except Exception as e:
            # Catch-all so malformed input never propagates out of verification.
            logger.error(f"Unexpected error validating token: {e}")
            return None
|