Initial commit - production deployment
This commit is contained in:
0
shared/auth/__init__.py
Executable file
0
shared/auth/__init__.py
Executable file
478
shared/auth/access_control.py
Executable file
478
shared/auth/access_control.py
Executable file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Subscription Tier and Role-Based Access Control Decorators
|
||||
Provides unified access control across all microservices
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List, Callable, Dict, Any, Optional
|
||||
from fastapi import HTTPException, status, Request, Depends
|
||||
import structlog
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionTier(Enum):
|
||||
"""
|
||||
Subscription tier hierarchy
|
||||
Matches project-wide subscription plans in tenant service
|
||||
"""
|
||||
STARTER = "starter"
|
||||
PROFESSIONAL = "professional"
|
||||
ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
class UserRole(Enum):
|
||||
"""
|
||||
User role hierarchy
|
||||
Matches project-wide role definitions in tenant member model
|
||||
"""
|
||||
VIEWER = "viewer"
|
||||
MEMBER = "member"
|
||||
ADMIN = "admin"
|
||||
OWNER = "owner"
|
||||
|
||||
|
||||
# Tier hierarchy for comparison (higher number = higher tier)
|
||||
TIER_HIERARCHY = {
|
||||
SubscriptionTier.STARTER: 1,
|
||||
SubscriptionTier.PROFESSIONAL: 2,
|
||||
SubscriptionTier.ENTERPRISE: 3,
|
||||
}
|
||||
|
||||
# Role hierarchy for comparison (higher number = more permissions)
|
||||
ROLE_HIERARCHY = {
|
||||
UserRole.VIEWER: 1,
|
||||
UserRole.MEMBER: 2,
|
||||
UserRole.ADMIN: 3,
|
||||
UserRole.OWNER: 4,
|
||||
}
|
||||
|
||||
|
||||
def check_tier_access(user_tier: str, required_tiers: List[str]) -> bool:
|
||||
"""
|
||||
Check if user's subscription tier meets the requirement
|
||||
|
||||
Args:
|
||||
user_tier: Current user's subscription tier
|
||||
required_tiers: List of allowed tiers
|
||||
|
||||
Returns:
|
||||
bool: True if access is allowed
|
||||
"""
|
||||
try:
|
||||
user_tier_enum = SubscriptionTier(user_tier.lower())
|
||||
user_tier_level = TIER_HIERARCHY.get(user_tier_enum, 0)
|
||||
|
||||
# Get minimum required tier level
|
||||
min_required_level = min(
|
||||
TIER_HIERARCHY.get(SubscriptionTier(tier.lower()), 999)
|
||||
for tier in required_tiers
|
||||
)
|
||||
|
||||
return user_tier_level >= min_required_level
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.warning("Invalid tier comparison", user_tier=user_tier, required=required_tiers, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def check_role_access(user_role: str, required_roles: List[str]) -> bool:
|
||||
"""
|
||||
Check if user's role meets the requirement
|
||||
|
||||
Args:
|
||||
user_role: Current user's role
|
||||
required_roles: List of allowed roles
|
||||
|
||||
Returns:
|
||||
bool: True if access is allowed
|
||||
"""
|
||||
try:
|
||||
user_role_enum = UserRole(user_role.lower())
|
||||
user_role_level = ROLE_HIERARCHY.get(user_role_enum, 0)
|
||||
|
||||
# Get minimum required role level
|
||||
min_required_level = min(
|
||||
ROLE_HIERARCHY.get(UserRole(role.lower()), 999)
|
||||
for role in required_roles
|
||||
)
|
||||
|
||||
return user_role_level >= min_required_level
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.warning("Invalid role comparison", user_role=user_role, required=required_roles, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def require_subscription_tier(allowed_tiers: List[str]):
|
||||
"""
|
||||
Decorator to enforce subscription tier access control
|
||||
|
||||
Usage:
|
||||
@router.get("/analytics/advanced")
|
||||
@require_subscription_tier(['professional', 'enterprise'])
|
||||
async def get_advanced_analytics(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_tiers: List of subscription tiers allowed to access this endpoint
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs (injected by get_current_user_dep)
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Get tenant's subscription tier from user context
|
||||
# The gateway should inject this information
|
||||
subscription_tier = current_user.get('subscription_tier')
|
||||
|
||||
if not subscription_tier:
|
||||
logger.warning("Subscription tier not found in user context", user_id=current_user.get('user_id'))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Subscription information unavailable"
|
||||
)
|
||||
|
||||
# Check tier access
|
||||
has_access = check_tier_access(subscription_tier, allowed_tiers)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
"Subscription tier access denied",
|
||||
user_tier=subscription_tier,
|
||||
required_tiers=allowed_tiers,
|
||||
user_id=current_user.get('user_id')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_tier_insufficient",
|
||||
"message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
|
||||
"current_plan": subscription_tier,
|
||||
"required_plans": allowed_tiers,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Subscription tier check passed", tier=subscription_tier, required=allowed_tiers)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_user_role(allowed_roles: List[str]):
|
||||
"""
|
||||
Decorator to enforce role-based access control
|
||||
|
||||
Usage:
|
||||
@router.delete("/ingredients/{id}")
|
||||
@require_user_role(['admin', 'manager'])
|
||||
async def delete_ingredient(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_roles: List of user roles allowed to access this endpoint
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Get user's role
|
||||
user_role = current_user.get('role', 'user')
|
||||
|
||||
# Check role access
|
||||
has_access = check_role_access(user_role, allowed_roles)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
"Role-based access denied",
|
||||
user_role=user_role,
|
||||
required_roles=allowed_roles,
|
||||
user_id=current_user.get('user_id')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "insufficient_permissions",
|
||||
"message": f"This action requires {' or '.join(allowed_roles)} role",
|
||||
"current_role": user_role,
|
||||
"required_roles": allowed_roles
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Role check passed", role=user_role, required=allowed_roles)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_tier_and_role(
|
||||
allowed_tiers: List[str],
|
||||
allowed_roles: List[str]
|
||||
):
|
||||
"""
|
||||
Combined decorator for both tier and role enforcement
|
||||
|
||||
Usage:
|
||||
@router.post("/analytics/custom-report")
|
||||
@require_tier_and_role(['professional', 'enterprise'], ['admin', 'manager'])
|
||||
async def create_custom_report(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_tiers: List of subscription tiers allowed
|
||||
allowed_roles: List of user roles allowed
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check subscription tier
|
||||
subscription_tier = current_user.get('subscription_tier')
|
||||
if subscription_tier:
|
||||
tier_access = check_tier_access(subscription_tier, allowed_tiers)
|
||||
if not tier_access:
|
||||
logger.warning(
|
||||
"Combined access control: tier check failed",
|
||||
user_tier=subscription_tier,
|
||||
required_tiers=allowed_tiers
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_tier_insufficient",
|
||||
"message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
|
||||
"current_plan": subscription_tier,
|
||||
"required_plans": allowed_tiers,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
# Check user role
|
||||
user_role = current_user.get('role', 'member')
|
||||
role_access = check_role_access(user_role, allowed_roles)
|
||||
|
||||
if not role_access:
|
||||
logger.warning(
|
||||
"Combined access control: role check failed",
|
||||
user_role=user_role,
|
||||
required_roles=allowed_roles
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "insufficient_permissions",
|
||||
"message": f"This action requires {' or '.join(allowed_roles)} role",
|
||||
"current_role": user_role,
|
||||
"required_roles": allowed_roles
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Combined access control passed",
|
||||
tier=subscription_tier,
|
||||
role=user_role,
|
||||
required_tiers=allowed_tiers,
|
||||
required_roles=allowed_roles
|
||||
)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# Convenience decorators for common patterns
|
||||
analytics_tier_required = require_subscription_tier(['professional', 'enterprise'])
|
||||
enterprise_tier_required = require_subscription_tier(['enterprise'])
|
||||
admin_role_required = require_user_role(['admin', 'owner'])
|
||||
owner_role_required = require_user_role(['owner'])
|
||||
|
||||
|
||||
def service_only_access(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to restrict endpoint access to service-to-service calls only
|
||||
|
||||
This decorator validates that:
|
||||
1. The request has a valid service token (type='service' in JWT)
|
||||
2. The token is from an authorized internal service
|
||||
|
||||
Usage:
|
||||
@router.delete("/tenant/{tenant_id}")
|
||||
@service_only_access
|
||||
async def delete_tenant_data(
|
||||
tenant_id: str,
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db = Depends(get_db)
|
||||
):
|
||||
# Service-only logic here
|
||||
|
||||
The decorator expects current_user to be injected via get_current_user_dep
|
||||
dependency, which should already contain the user/service context from JWT.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs (injected by get_current_user_dep)
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Service-only access: current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check if this is a service token
|
||||
user_type = current_user.get('type', '')
|
||||
is_service = current_user.get('is_service', False)
|
||||
|
||||
if user_type != 'service' and not is_service:
|
||||
logger.warning(
|
||||
"Service-only access denied: not a service token",
|
||||
user_id=current_user.get('user_id'),
|
||||
user_type=user_type,
|
||||
is_service=is_service
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This endpoint is only accessible to internal services"
|
||||
)
|
||||
|
||||
# Log successful service access
|
||||
service_name = current_user.get('service', current_user.get('user_id', 'unknown'))
|
||||
logger.info(
|
||||
"Service-only access granted",
|
||||
service=service_name,
|
||||
endpoint=func.__name__
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_verified_subscription_tier(
|
||||
allowed_tiers: List[str],
|
||||
verify_in_database: bool = False
|
||||
):
|
||||
"""
|
||||
Subscription tier enforcement with optional database verification.
|
||||
|
||||
Args:
|
||||
allowed_tiers: List of allowed subscription tiers
|
||||
verify_in_database: If True, verify against database (for critical operations)
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs.get('request') or args[0]
|
||||
|
||||
# Get tier from gateway-injected header (from verified JWT)
|
||||
header_tier = request.headers.get("x-subscription-tier", "starter").lower()
|
||||
|
||||
if header_tier not in [t.lower() for t in allowed_tiers]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_required",
|
||||
"message": f"This feature requires {', '.join(allowed_tiers)} subscription",
|
||||
"current_tier": header_tier,
|
||||
"required_tiers": allowed_tiers
|
||||
}
|
||||
)
|
||||
|
||||
# For critical operations, verify against database
|
||||
if verify_in_database:
|
||||
tenant_id = request.headers.get("x-tenant-id")
|
||||
if tenant_id:
|
||||
db_tier = await _verify_subscription_in_database(tenant_id)
|
||||
if db_tier.lower() != header_tier:
|
||||
logger.error(
|
||||
"Subscription tier mismatch detected!",
|
||||
header_tier=header_tier,
|
||||
db_tier=db_tier,
|
||||
tenant_id=tenant_id,
|
||||
user_id=request.headers.get("x-user-id")
|
||||
)
|
||||
# Use database tier (authoritative)
|
||||
if db_tier.lower() not in [t.lower() for t in allowed_tiers]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_verification_failed",
|
||||
"message": "Subscription tier verification failed"
|
||||
}
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def _verify_subscription_in_database(tenant_id: str) -> str:
|
||||
"""
|
||||
Direct database verification of subscription tier.
|
||||
Used for critical operations as defense-in-depth.
|
||||
"""
|
||||
from shared.clients.subscription_client import SubscriptionClient
|
||||
|
||||
client = SubscriptionClient()
|
||||
subscription = await client.get_subscription(tenant_id)
|
||||
return subscription.get("plan", "starter")
|
||||
704
shared/auth/decorators.py
Executable file
704
shared/auth/decorators.py
Executable file
@@ -0,0 +1,704 @@
|
||||
# ================================================================
|
||||
# shared/auth/decorators.py - ENHANCED WITH ADMIN ROLE DECORATOR
|
||||
# ================================================================
|
||||
"""
|
||||
Enhanced authentication decorators for microservices including admin role validation.
|
||||
Designed to work with gateway authentication middleware and provide centralized
|
||||
role-based access control across all services.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException, status, Request, Depends
|
||||
from fastapi.security import HTTPBearer
|
||||
from typing import Callable, Optional, Dict, Any, List
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Bearer token scheme for services that need it
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
def require_authentication(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require authentication - trusts gateway validation
|
||||
Services behind the gateway should use this decorator
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Find request object in arguments
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
|
||||
if not request:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Request object not found"
|
||||
)
|
||||
|
||||
# Check if user context exists (set by gateway)
|
||||
if not hasattr(request.state, 'user') or not request.state.user:
|
||||
# Check headers as fallback (for direct service calls in dev)
|
||||
user_info = extract_user_from_headers(request)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
request.state.user = user_info
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def require_tenant_access(func: Callable) -> Callable:
|
||||
"""Decorator to require tenant access"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
|
||||
if not request or not hasattr(request.state, 'tenant_id'):
|
||||
# Try to extract from headers
|
||||
tenant_id = extract_tenant_from_headers(request)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Tenant access required"
|
||||
)
|
||||
request.state.tenant_id = tenant_id
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def require_role(role: str):
|
||||
"""Decorator to require specific role"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = None
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
|
||||
user = get_current_user(request)
|
||||
user_role = user.get('role', '').lower()
|
||||
|
||||
if user_role != role.lower() and user_role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"{role} role required"
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def require_admin_role(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require admin role - simplified version for FastAPI dependencies
|
||||
|
||||
This decorator ensures only users with 'admin' role can access the endpoint.
|
||||
Can be used as a FastAPI dependency or function decorator.
|
||||
|
||||
Usage as dependency:
|
||||
@router.delete("/admin/users/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
):
|
||||
# Admin-only logic here
|
||||
|
||||
Usage as decorator:
|
||||
@require_admin_role
|
||||
@router.delete("/admin/users/{user_id}")
|
||||
async def delete_user(...):
|
||||
# Admin-only logic here
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Find request object in arguments
|
||||
request = None
|
||||
current_user = None
|
||||
|
||||
# Extract request and current_user from arguments
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
|
||||
# Check kwargs for request and current_user
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
if not current_user:
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
# If we still don't have current_user, try to get it from request
|
||||
if not current_user and request:
|
||||
current_user = get_current_user(request)
|
||||
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check if user has admin role
|
||||
user_role = current_user.get('role', '').lower()
|
||||
|
||||
if user_role != 'admin':
|
||||
logger.warning("Non-admin user attempted admin operation",
|
||||
user_id=current_user.get('user_id'),
|
||||
role=user_role)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
logger.info("Admin operation authorized",
|
||||
user_id=current_user.get('user_id'),
|
||||
endpoint=func.__name__)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def require_roles(allowed_roles: List[str]):
|
||||
"""
|
||||
Decorator to require one of multiple roles
|
||||
|
||||
Args:
|
||||
allowed_roles: List of roles that are allowed to access the endpoint
|
||||
|
||||
Usage:
|
||||
@require_roles(['admin', 'manager'])
|
||||
@router.post("/sensitive-operation")
|
||||
async def sensitive_operation(...):
|
||||
# Only admins and managers can access
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = None
|
||||
current_user = None
|
||||
|
||||
# Extract request and current_user from arguments
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
|
||||
# Check kwargs
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
if not current_user:
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
# Get user from request if not provided
|
||||
if not current_user and request:
|
||||
current_user = get_current_user(request)
|
||||
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check if user has one of the allowed roles
|
||||
user_role = current_user.get('role', '').lower()
|
||||
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||
|
||||
if user_role not in allowed_roles_lower:
|
||||
logger.warning("Unauthorized role attempted restricted operation",
|
||||
user_id=current_user.get('user_id'),
|
||||
user_role=user_role,
|
||||
allowed_roles=allowed_roles)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"One of these roles required: {', '.join(allowed_roles)}"
|
||||
)
|
||||
|
||||
logger.info("Role-based operation authorized",
|
||||
user_id=current_user.get('user_id'),
|
||||
user_role=user_role,
|
||||
endpoint=func.__name__)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def require_tenant_admin(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to require admin role within a specific tenant context
|
||||
|
||||
This checks that the user is an admin AND has access to the tenant
|
||||
being operated on. Useful for tenant-scoped admin operations.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = None
|
||||
current_user = None
|
||||
|
||||
# Extract request and current_user from arguments
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
|
||||
if not request:
|
||||
request = kwargs.get('request')
|
||||
if not current_user:
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user and request:
|
||||
current_user = get_current_user(request)
|
||||
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check admin role first
|
||||
user_role = current_user.get('role', '').lower()
|
||||
if user_role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
# Check tenant access
|
||||
tenant_id = get_current_tenant_id(request) if request else None
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Tenant context required"
|
||||
)
|
||||
|
||||
# Additional tenant admin validation could go here
|
||||
# For now, we trust that admin users have access to operate on any tenant
|
||||
|
||||
logger.info("Tenant admin operation authorized",
|
||||
user_id=current_user.get('user_id'),
|
||||
tenant_id=tenant_id,
|
||||
endpoint=func.__name__)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_current_user(request: Request) -> Dict[str, Any]:
|
||||
"""Get current user from request state or headers"""
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return request.state.user
|
||||
|
||||
# Fallback to headers (for dev/testing)
|
||||
user_info = extract_user_from_headers(request)
|
||||
if not user_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not authenticated"
|
||||
)
|
||||
|
||||
return user_info
|
||||
|
||||
def get_current_tenant_id(request: Request) -> Optional[str]:
|
||||
"""Get current tenant ID from request state or headers"""
|
||||
if hasattr(request.state, 'tenant_id'):
|
||||
return request.state.tenant_id
|
||||
|
||||
# Fallback to headers
|
||||
return extract_tenant_from_headers(request)
|
||||
|
||||
def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Extract user information from forwarded headers (gateway sets these)"""
|
||||
user_id = request.headers.get("x-user-id")
|
||||
logger.info(f"🔍 Extracting user from headers",
|
||||
user_id=user_id,
|
||||
has_user_id=bool(user_id),
|
||||
path=request.url.path)
|
||||
|
||||
if not user_id:
|
||||
logger.warning(f"❌ No x-user-id header found", path=request.url.path)
|
||||
return None
|
||||
|
||||
user_context = {
|
||||
"user_id": user_id,
|
||||
"email": request.headers.get("x-user-email", ""),
|
||||
"role": request.headers.get("x-user-role", "user"),
|
||||
"tenant_id": request.headers.get("x-tenant-id"),
|
||||
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [],
|
||||
"full_name": request.headers.get("x-user-full-name", ""),
|
||||
"subscription_tier": request.headers.get("x-subscription-tier", ""),
|
||||
"is_demo": request.headers.get("x-is-demo", "").lower() == "true",
|
||||
"demo_session_id": request.headers.get("x-demo-session-id", ""),
|
||||
"demo_account_type": request.headers.get("x-demo-account-type", "")
|
||||
}
|
||||
|
||||
logger.info(f"✅ User context extracted from headers",
|
||||
user_context=user_context,
|
||||
path=request.url.path)
|
||||
|
||||
# ✅ ADD THIS: Handle service tokens properly
|
||||
user_type = request.headers.get("x-user-type", "")
|
||||
service_name = request.headers.get("x-service-name", "")
|
||||
|
||||
if user_type == "service" or service_name:
|
||||
user_context.update({
|
||||
"type": "service",
|
||||
"service": service_name,
|
||||
"role": "admin", # Service tokens always have admin role
|
||||
"is_service": True
|
||||
})
|
||||
|
||||
return user_context
|
||||
|
||||
def extract_tenant_from_headers(request: Request) -> Optional[str]:
|
||||
"""Extract tenant ID from headers"""
|
||||
return request.headers.get("x-tenant-id")
|
||||
|
||||
def extract_user_from_jwt(auth_header: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract user information from JWT token
|
||||
This is a fallback for when gateway doesn't inject x-user-* headers
|
||||
"""
|
||||
try:
|
||||
from jose import jwt
|
||||
from shared.config.base import is_internal_service
|
||||
|
||||
# Remove "Bearer " prefix
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
|
||||
# Decode without verification (we trust tokens from gateway)
|
||||
# In production, you'd verify with the secret key
|
||||
payload = jwt.decode(token, key="dummy", options={"verify_signature": False})
|
||||
|
||||
logger.debug("JWT payload decoded", payload_keys=list(payload.keys()))
|
||||
|
||||
# Extract user information from JWT payload
|
||||
user_id = payload.get("sub") or payload.get("user_id") or payload.get("service")
|
||||
|
||||
if not user_id:
|
||||
logger.warning("No user_id found in JWT payload", payload=payload)
|
||||
return None
|
||||
|
||||
# Check if this is a service token
|
||||
token_type = payload.get("type", "")
|
||||
service_name = payload.get("service", "")
|
||||
|
||||
if token_type == "service" or is_internal_service(user_id) or is_internal_service(service_name):
|
||||
# This is a service token
|
||||
service_identifier = service_name or user_id
|
||||
user_context = {
|
||||
"user_id": service_identifier,
|
||||
"type": "service",
|
||||
"service": service_identifier,
|
||||
"role": "admin", # Services get admin privileges
|
||||
"is_service": True,
|
||||
"permissions": ["read", "write", "admin"],
|
||||
"email": f"{service_identifier}@internal.service",
|
||||
"full_name": f"{service_identifier.replace('-', ' ').title()}"
|
||||
}
|
||||
logger.info("Service authenticated via JWT", service=service_identifier)
|
||||
else:
|
||||
# This is a user token
|
||||
user_context = {
|
||||
"user_id": user_id,
|
||||
"type": "user",
|
||||
"email": payload.get("email", ""),
|
||||
"role": payload.get("role", "user"),
|
||||
"tenant_id": payload.get("tenant_id"),
|
||||
"permissions": payload.get("permissions", []),
|
||||
"full_name": payload.get("full_name", ""),
|
||||
"subscription_tier": payload.get("subscription_tier", ""),
|
||||
"is_service": False
|
||||
}
|
||||
logger.info("User authenticated via JWT", user_id=user_id)
|
||||
|
||||
return user_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract user from JWT", error=str(e), error_type=type(e).__name__)
|
||||
return None
|
||||
|
||||
# ================================================================
|
||||
# FASTAPI DEPENDENCY FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_current_user_dep(request: Request) -> Dict[str, Any]:
|
||||
"""FastAPI dependency to get current user - ENHANCED with JWT fallback for services"""
|
||||
try:
|
||||
# Enhanced logging for debugging
|
||||
logger.info(
|
||||
"🔐 Authentication attempt",
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
has_auth_header=bool(request.headers.get("authorization")),
|
||||
has_x_user_id=bool(request.headers.get("x-user-id")),
|
||||
has_x_is_demo=bool(request.headers.get("x-is-demo")),
|
||||
has_x_demo_session_id=bool(request.headers.get("x-demo-session-id")),
|
||||
x_user_id=request.headers.get("x-user-id", "MISSING"),
|
||||
x_is_demo=request.headers.get("x-is-demo", "MISSING"),
|
||||
x_demo_session_id=request.headers.get("x-demo-session-id", "MISSING"),
|
||||
client_ip=request.client.host if request.client else "unknown"
|
||||
)
|
||||
|
||||
# Try to get user from headers first (preferred method)
|
||||
user = None
|
||||
try:
|
||||
user = get_current_user(request)
|
||||
except HTTPException:
|
||||
# If headers are missing, try JWT token as fallback
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
user = extract_user_from_jwt(auth_header)
|
||||
if user:
|
||||
logger.info(
|
||||
"User authenticated via JWT fallback",
|
||||
user_id=user.get("user_id"),
|
||||
user_type=user.get("type", "user"),
|
||||
is_service=user.get("type") == "service",
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
# If still no user, raise original exception
|
||||
if not user:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"User authenticated successfully",
|
||||
user_id=user.get("user_id"),
|
||||
user_type=user.get("type", "user"),
|
||||
is_service=user.get("type") == "service",
|
||||
role=user.get("role"),
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
except HTTPException as e:
|
||||
logger.warning(
|
||||
"Authentication failed - 401",
|
||||
path=request.url.path,
|
||||
status_code=e.status_code,
|
||||
detail=e.detail,
|
||||
has_x_user_id=bool(request.headers.get("x-user-id")),
|
||||
has_auth_header=bool(request.headers.get("authorization")),
|
||||
x_user_type=request.headers.get("x-user-type", "none"),
|
||||
x_service_name=request.headers.get("x-service-name", "none"),
|
||||
client_ip=request.client.host if request.client else "unknown"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_current_tenant_id_dep(request: Request) -> Optional[str]:
|
||||
"""FastAPI dependency to get current tenant ID"""
|
||||
return get_current_tenant_id(request)
|
||||
|
||||
async def require_admin_role_dep(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency that requires admin role
|
||||
|
||||
Usage:
|
||||
@router.delete("/admin/users/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
admin_user: Dict[str, Any] = Depends(require_admin_role_dep)
|
||||
):
|
||||
# admin_user is guaranteed to have admin role
|
||||
"""
|
||||
|
||||
user_role = current_user.get('role', '').lower()
|
||||
|
||||
if user_role != 'admin':
|
||||
logger.warning("Non-admin user attempted admin operation",
|
||||
user_id=current_user.get('user_id'),
|
||||
role=user_role)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin role required"
|
||||
)
|
||||
|
||||
logger.info("Admin operation authorized via dependency",
|
||||
user_id=current_user.get('user_id'))
|
||||
|
||||
return current_user
|
||||
|
||||
async def require_roles_dep(allowed_roles: List[str]):
|
||||
"""
|
||||
FastAPI dependency factory that requires one of multiple roles
|
||||
|
||||
Usage:
|
||||
require_manager_or_admin = require_roles_dep(['admin', 'manager'])
|
||||
|
||||
@router.post("/sensitive-operation")
|
||||
async def sensitive_operation(
|
||||
user: Dict[str, Any] = Depends(require_manager_or_admin)
|
||||
):
|
||||
# Only admins and managers can access
|
||||
"""
|
||||
|
||||
async def check_roles(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
) -> Dict[str, Any]:
|
||||
user_role = current_user.get('role', '').lower()
|
||||
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||
|
||||
if user_role not in allowed_roles_lower:
|
||||
logger.warning("Unauthorized role attempted restricted operation",
|
||||
user_id=current_user.get('user_id'),
|
||||
user_role=user_role,
|
||||
allowed_roles=allowed_roles)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"One of these roles required: {', '.join(allowed_roles)}"
|
||||
)
|
||||
|
||||
logger.info("Role-based operation authorized via dependency",
|
||||
user_id=current_user.get('user_id'),
|
||||
user_role=user_role)
|
||||
|
||||
return current_user
|
||||
|
||||
return check_roles
|
||||
|
||||
# ================================================================
|
||||
# UTILITY FUNCTIONS FOR ROLE CHECKING
|
||||
# ================================================================
|
||||
|
||||
def is_admin_user(user: Dict[str, Any]) -> bool:
|
||||
"""Check if user has admin role"""
|
||||
return user.get('role', '').lower() == 'admin'
|
||||
|
||||
def is_user_in_roles(user: Dict[str, Any], allowed_roles: List[str]) -> bool:
|
||||
"""Check if user has one of the allowed roles"""
|
||||
user_role = user.get('role', '').lower()
|
||||
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||
return user_role in allowed_roles_lower
|
||||
|
||||
def get_user_permissions(user: Dict[str, Any]) -> List[str]:
|
||||
"""Get user permissions list"""
|
||||
return user.get('permissions', [])
|
||||
|
||||
def has_permission(user: Dict[str, Any], permission: str) -> bool:
|
||||
"""Check if user has specific permission"""
|
||||
permissions = get_user_permissions(user)
|
||||
return permission in permissions
|
||||
|
||||
# ================================================================
|
||||
# ADVANCED ROLE DECORATORS
|
||||
# ================================================================
|
||||
|
||||
def require_permission(permission: str):
|
||||
"""
|
||||
Decorator to require specific permission
|
||||
|
||||
Usage:
|
||||
@require_permission('delete_users')
|
||||
@router.delete("/users/{user_id}")
|
||||
async def delete_user(...):
|
||||
# Only users with 'delete_users' permission can access
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
current_user = None
|
||||
|
||||
# Extract current_user from arguments
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check permission
|
||||
if not has_permission(current_user, permission):
|
||||
# Admins bypass permission checks
|
||||
if not is_admin_user(current_user):
|
||||
logger.warning("User lacks required permission",
|
||||
user_id=current_user.get('user_id'),
|
||||
required_permission=permission,
|
||||
user_permissions=get_user_permissions(current_user))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{permission}' required"
|
||||
)
|
||||
|
||||
logger.info("Permission-based operation authorized",
|
||||
user_id=current_user.get('user_id'),
|
||||
permission=permission)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
# Export all decorators and functions
|
||||
__all__ = [
|
||||
# Main decorators
|
||||
'require_authentication',
|
||||
'require_tenant_access',
|
||||
'require_role',
|
||||
'require_admin_role',
|
||||
'require_roles',
|
||||
'require_tenant_admin',
|
||||
'require_permission',
|
||||
|
||||
# FastAPI dependencies
|
||||
'get_current_user_dep',
|
||||
'get_current_tenant_id_dep',
|
||||
'require_admin_role_dep',
|
||||
'require_roles_dep',
|
||||
|
||||
# Utility functions
|
||||
'get_current_user',
|
||||
'get_current_tenant_id',
|
||||
'extract_user_from_headers',
|
||||
'extract_user_from_jwt',
|
||||
'extract_tenant_from_headers',
|
||||
'is_admin_user',
|
||||
'is_user_in_roles',
|
||||
'get_user_permissions',
|
||||
'has_permission'
|
||||
]
|
||||
292
shared/auth/jwt_handler.py
Executable file
292
shared/auth/jwt_handler.py
Executable file
@@ -0,0 +1,292 @@
|
||||
# shared/auth/jwt_handler.py
|
||||
"""
|
||||
Enhanced JWT Handler with proper token structure and validation
|
||||
FIXED VERSION - Consistent token format between all services
|
||||
"""
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class JWTHandler:
|
||||
"""Enhanced JWT token handling with consistent format"""
|
||||
|
||||
def __init__(self, secret_key: str, algorithm: str = "HS256"):
|
||||
self.secret_key = secret_key
|
||||
self.algorithm = algorithm
|
||||
|
||||
def create_access_token_from_payload(self, payload: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create JWT ACCESS token from complete payload
|
||||
✅ FIXED: Only creates access tokens with access token structure
|
||||
"""
|
||||
try:
|
||||
# Ensure this is marked as an access token
|
||||
payload["type"] = "access"
|
||||
|
||||
encoded_jwt = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created access token with payload keys: {list(payload.keys())}")
|
||||
return encoded_jwt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Access token creation failed: {e}")
|
||||
raise ValueError(f"Failed to encode access token: {str(e)}")
|
||||
|
||||
def create_refresh_token_from_payload(self, payload: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create JWT REFRESH token from complete payload
|
||||
✅ FIXED: Only creates refresh tokens with refresh token structure
|
||||
"""
|
||||
try:
|
||||
# Ensure this is marked as a refresh token
|
||||
payload["type"] = "refresh"
|
||||
|
||||
encoded_jwt = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created refresh token with payload keys: {list(payload.keys())}")
|
||||
return encoded_jwt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Refresh token creation failed: {e}")
|
||||
raise ValueError(f"Failed to encode refresh token: {str(e)}")
|
||||
|
||||
def create_access_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create JWT access token with STANDARD structure (legacy method)
|
||||
✅ FIXED: Consistent payload format for access tokens
|
||||
"""
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"email": user_data["email"],
|
||||
"type": "access"
|
||||
}
|
||||
|
||||
# Add optional fields if present
|
||||
if "full_name" in user_data:
|
||||
to_encode["full_name"] = user_data["full_name"]
|
||||
if "is_verified" in user_data:
|
||||
to_encode["is_verified"] = user_data["is_verified"]
|
||||
if "is_active" in user_data:
|
||||
to_encode["is_active"] = user_data["is_active"]
|
||||
|
||||
# Set expiration
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iss": "bakery-auth"
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created access token for user {user_data['email']}")
|
||||
return encoded_jwt
|
||||
|
||||
def create_refresh_token(self, user_data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create JWT refresh token with MINIMAL payload (legacy method)
|
||||
✅ FIXED: Consistent refresh token structure, different from access
|
||||
"""
|
||||
to_encode = {
|
||||
"sub": user_data["user_id"],
|
||||
"user_id": user_data["user_id"],
|
||||
"type": "refresh"
|
||||
}
|
||||
|
||||
# Add unique identifier to prevent duplicates
|
||||
if "jti" in user_data:
|
||||
to_encode["jti"] = user_data["jti"]
|
||||
else:
|
||||
import uuid
|
||||
to_encode["jti"] = str(uuid.uuid4())
|
||||
|
||||
# Include email only if available (optional for refresh tokens)
|
||||
if "email" in user_data and user_data["email"]:
|
||||
to_encode["email"] = user_data["email"]
|
||||
|
||||
# Set expiration
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iss": "bakery-auth"
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created refresh token for user {user_data['user_id']}")
|
||||
return encoded_jwt
|
||||
|
||||
def create_service_token(
|
||||
self,
|
||||
service_name: str,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Create JWT SERVICE token for inter-service communication
|
||||
✅ UNIFIED: Single source of truth for all service token creation
|
||||
✅ ENHANCED: Supports tenant context for tenant-scoped operations
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (e.g., 'auth-service', 'demo-session')
|
||||
expires_delta: Optional expiration time (defaults to 1 hour for inter-service calls)
|
||||
tenant_id: Optional tenant ID for tenant-scoped service operations
|
||||
|
||||
Returns:
|
||||
Encoded JWT service token
|
||||
"""
|
||||
to_encode = {
|
||||
"sub": service_name,
|
||||
"user_id": f"{service_name}-service",
|
||||
"email": f"{service_name}-service@internal",
|
||||
"service": service_name,
|
||||
"type": "service",
|
||||
"role": "admin", # Services have admin privileges
|
||||
"is_service": True,
|
||||
"full_name": f"{service_name.title()} Service",
|
||||
"is_verified": True,
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
# Include tenant context when provided for tenant-scoped operations
|
||||
if tenant_id:
|
||||
to_encode["tenant_id"] = tenant_id
|
||||
|
||||
# Set expiration (default to 1 hour for inter-service calls)
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=1) # 1 hour default
|
||||
|
||||
to_encode.update({
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iss": "bakery-auth"
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
logger.debug(f"Created service token for service {service_name}", tenant_id=tenant_id)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify and decode JWT token
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
|
||||
# Check if token is expired
|
||||
exp_timestamp = payload.get("exp")
|
||||
if exp_timestamp:
|
||||
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
|
||||
if datetime.now(timezone.utc) > exp_datetime:
|
||||
logger.debug("Token is expired")
|
||||
return None
|
||||
|
||||
logger.debug(f"Token verified successfully, type: {payload.get('type', 'unknown')}")
|
||||
return payload
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
return None
|
||||
|
||||
def decode_token_no_verify(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Decode JWT token without verification (for inspection purposes)
|
||||
"""
|
||||
try:
|
||||
# Decode without verification - need to provide key but disable verification
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm], options={"verify_signature": False})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"Token decoding failed: {e}")
|
||||
raise ValueError("Invalid token format")
|
||||
|
||||
def get_token_type(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
Get the type of token (access or refresh) without full verification
|
||||
"""
|
||||
try:
|
||||
payload = self.decode_token_no_verify(token)
|
||||
return payload.get("type")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def is_token_expired(self, token: str) -> bool:
|
||||
"""
|
||||
Check if token is expired without full verification
|
||||
"""
|
||||
try:
|
||||
payload = self.decode_token_no_verify(token)
|
||||
exp_timestamp = payload.get("exp")
|
||||
if exp_timestamp:
|
||||
exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
|
||||
return datetime.now(timezone.utc) > exp_datetime
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
def extract_user_id(self, token: str) -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from token without full verification
|
||||
Useful for quick user identification
|
||||
"""
|
||||
try:
|
||||
payload = self.decode_token_no_verify(token)
|
||||
if payload:
|
||||
return payload.get("user_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract user ID from token: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_token_info(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive token information for debugging
|
||||
"""
|
||||
info = {
|
||||
"valid": False,
|
||||
"expired": True,
|
||||
"user_id": None,
|
||||
"email": None,
|
||||
"type": None,
|
||||
"exp": None,
|
||||
"iat": None
|
||||
}
|
||||
|
||||
try:
|
||||
# Try unsafe decode first
|
||||
payload = self.decode_token_no_verify(token)
|
||||
if payload:
|
||||
info.update({
|
||||
"user_id": payload.get("user_id"),
|
||||
"email": payload.get("email"),
|
||||
"type": payload.get("type"),
|
||||
"exp": payload.get("exp"),
|
||||
"iat": payload.get("iat"),
|
||||
"expired": self.is_token_expired(token)
|
||||
})
|
||||
|
||||
# Try full verification
|
||||
verified_payload = self.verify_token(token)
|
||||
info["valid"] = verified_payload is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get token info: {e}")
|
||||
|
||||
return info
|
||||
529
shared/auth/tenant_access.py
Executable file
529
shared/auth/tenant_access.py
Executable file
@@ -0,0 +1,529 @@
|
||||
# ================================================================
|
||||
# shared/auth/tenant_access.py - Complete Implementation
|
||||
# ================================================================
|
||||
"""
|
||||
Tenant access control utilities for microservices
|
||||
Provides both gateway-level and service-level tenant access verification
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, Optional
|
||||
import httpx
|
||||
import structlog
|
||||
from fastapi import HTTPException, Depends
|
||||
import asyncio
|
||||
|
||||
# Import FastAPI dependencies
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
|
||||
# Import settings (adjust import path based on your project structure)
|
||||
try:
|
||||
from app.core.config import settings
|
||||
except ImportError:
|
||||
try:
|
||||
from core.config import settings
|
||||
except ImportError:
|
||||
# Fallback for different project structures
|
||||
import os
|
||||
class Settings:
|
||||
TENANT_SERVICE_URL = os.getenv("TENANT_SERVICE_URL", "http://tenant-service:8000")
|
||||
settings = Settings()
|
||||
|
||||
# Setup logging
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class TenantAccessManager:
|
||||
"""
|
||||
Centralized tenant access management for both gateway and service level
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client=None):
|
||||
"""
|
||||
Initialize tenant access manager
|
||||
|
||||
Args:
|
||||
redis_client: Optional Redis client for caching
|
||||
"""
|
||||
self.redis_client = redis_client
|
||||
|
||||
async def verify_basic_tenant_access(self, user_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Gateway-level: Basic tenant access verification with caching
|
||||
|
||||
Args:
|
||||
user_id: User ID to verify
|
||||
tenant_id: Tenant ID to check access for
|
||||
|
||||
Returns:
|
||||
bool: True if user has access to tenant
|
||||
"""
|
||||
# Check cache first (5-minute TTL)
|
||||
cache_key = f"tenant_access:{user_id}:{tenant_id}"
|
||||
if self.redis_client:
|
||||
try:
|
||||
cached_result = await self.redis_client.get(cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result.decode() == "true" if isinstance(cached_result, bytes) else cached_result == "true"
|
||||
except Exception as cache_error:
|
||||
logger.warning(f"Cache lookup failed: {cache_error}")
|
||||
|
||||
# Verify with tenant service
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=2.0) as client: # Short timeout for gateway
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/access/{user_id}"
|
||||
)
|
||||
|
||||
has_access = response.status_code == 200
|
||||
|
||||
# If direct access check fails, check hierarchical access
|
||||
if not has_access:
|
||||
hierarchical_access = await self._check_hierarchical_access(user_id, tenant_id)
|
||||
has_access = hierarchical_access
|
||||
|
||||
# Cache result (5 minutes)
|
||||
if self.redis_client:
|
||||
try:
|
||||
await self.redis_client.setex(cache_key, 300, "true" if has_access else "false")
|
||||
except Exception as cache_error:
|
||||
logger.warning(f"Cache set failed: {cache_error}")
|
||||
|
||||
logger.debug(f"Tenant access check",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
has_access=has_access)
|
||||
|
||||
return has_access
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout verifying tenant access: user={user_id}, tenant={tenant_id}")
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return True
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error verifying tenant access: {e}")
|
||||
# Fail open for availability
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Gateway tenant access verification failed: {e}")
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return True
|
||||
|
||||
async def _check_hierarchical_access(self, user_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Check if user has hierarchical access (parent tenant access to child)
|
||||
|
||||
Args:
|
||||
user_id: User ID to verify
|
||||
tenant_id: Target tenant ID to check access for
|
||||
|
||||
Returns:
|
||||
bool: True if user has hierarchical access to the tenant
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
|
||||
# If this is a child tenant, check if user has access to parent
|
||||
if parent_tenant_id:
|
||||
# Check if user has access to parent tenant
|
||||
parent_access = await self._check_parent_access(user_id, parent_tenant_id)
|
||||
if parent_access:
|
||||
# For aggregated data only, allow parent access to child
|
||||
# Detailed child data requires direct access
|
||||
user_role = await self.get_user_role_in_tenant(user_id, parent_tenant_id)
|
||||
if user_role in ["owner", "admin", "network_admin"]:
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check hierarchical access: {e}")
|
||||
return False
|
||||
|
||||
async def _check_parent_access(self, user_id: str, parent_tenant_id: str) -> bool:
|
||||
"""
|
||||
Check if user has access to parent tenant (owner, admin, or network_admin role)
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
parent_tenant_id: Parent tenant ID
|
||||
|
||||
Returns:
|
||||
bool: True if user has access to parent tenant
|
||||
"""
|
||||
user_role = await self.get_user_role_in_tenant(user_id, parent_tenant_id)
|
||||
return user_role in ["owner", "admin", "network_admin"]
|
||||
|
||||
async def verify_hierarchical_access(self, user_id: str, tenant_id: str) -> dict:
|
||||
"""
|
||||
Verify hierarchical access and return access type and permissions
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
tenant_id: Target tenant ID
|
||||
|
||||
Returns:
|
||||
dict: Access information including access_type, can_view_children, etc.
|
||||
"""
|
||||
# First check direct access
|
||||
direct_access = await self._check_direct_access(user_id, tenant_id)
|
||||
|
||||
if direct_access:
|
||||
return {
|
||||
"access_type": "direct",
|
||||
"has_access": True,
|
||||
"can_view_children": False,
|
||||
"tenant_id": tenant_id
|
||||
}
|
||||
|
||||
# Check if this is a child tenant and user has parent access
|
||||
hierarchy_info = await self._get_tenant_hierarchy(tenant_id)
|
||||
|
||||
if hierarchy_info and hierarchy_info.get("parent_tenant_id"):
|
||||
parent_tenant_id = hierarchy_info["parent_tenant_id"]
|
||||
parent_access = await self._check_parent_access(user_id, parent_tenant_id)
|
||||
|
||||
if parent_access:
|
||||
user_role = await self.get_user_role_in_tenant(user_id, parent_tenant_id)
|
||||
|
||||
# Network admins have full access across entire hierarchy
|
||||
if user_role == "network_admin":
|
||||
return {
|
||||
"access_type": "hierarchical",
|
||||
"has_access": True,
|
||||
"tenant_id": tenant_id,
|
||||
"parent_tenant_id": parent_tenant_id,
|
||||
"is_network_admin": True,
|
||||
"can_view_children": True
|
||||
}
|
||||
# Regular admins have read-only access to children aggregated data
|
||||
elif user_role in ["owner", "admin"]:
|
||||
return {
|
||||
"access_type": "hierarchical",
|
||||
"has_access": True,
|
||||
"tenant_id": tenant_id,
|
||||
"parent_tenant_id": parent_tenant_id,
|
||||
"is_network_admin": False,
|
||||
"can_view_children": True # Can view aggregated data, not detailed
|
||||
}
|
||||
|
||||
return {
|
||||
"access_type": "none",
|
||||
"has_access": False,
|
||||
"tenant_id": tenant_id,
|
||||
"can_view_children": False
|
||||
}
|
||||
|
||||
async def _check_direct_access(self, user_id: str, tenant_id: str) -> bool:
|
||||
"""
|
||||
Check direct access to tenant (without hierarchy)
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=2.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/access/{user_id}"
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check direct access: {e}")
|
||||
return False
|
||||
|
||||
async def _get_tenant_hierarchy(self, tenant_id: str) -> dict:
|
||||
"""
|
||||
Get tenant hierarchy information
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
dict: Hierarchy information
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get tenant hierarchy: {e}")
|
||||
return {}
|
||||
|
||||
async def get_accessible_tenants_hierarchy(self, user_id: str) -> list:
|
||||
"""
|
||||
Get all tenants a user has access to, organized in hierarchy
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
list: List of tenants with hierarchy structure
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/users/{user_id}/hierarchy"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
tenants = response.json()
|
||||
logger.debug(f"Retrieved user tenants with hierarchy",
|
||||
user_id=user_id,
|
||||
tenant_count=len(tenants))
|
||||
return tenants
|
||||
else:
|
||||
logger.warning(f"Failed to get user tenants hierarchy: {response.status_code}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user tenants hierarchy: {e}")
|
||||
return []
|
||||
|
||||
async def get_user_role_in_tenant(self, user_id: str, tenant_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get user's role within a specific tenant
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Optional[str]: User's role in tenant (owner, admin, manager, user, network_admin) or None
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/members/{user_id}"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
role = data.get("role")
|
||||
logger.debug(f"User role in tenant",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
role=role)
|
||||
return role
|
||||
elif response.status_code == 404:
|
||||
logger.debug(f"User not found in tenant",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id)
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Unexpected response getting user role: {response.status_code}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user role in tenant: {e}")
|
||||
return None
|
||||
|
||||
async def verify_resource_permission(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
resource: str,
|
||||
action: str
|
||||
) -> bool:
|
||||
"""
|
||||
Fine-grained resource permission check (used by services)
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
tenant_id: Tenant ID
|
||||
resource: Resource type (sales, training, forecasts, etc.)
|
||||
action: Action being performed (read, write, delete, etc.)
|
||||
|
||||
Returns:
|
||||
bool: True if user has permission
|
||||
"""
|
||||
user_role = await self.get_user_role_in_tenant(user_id, tenant_id)
|
||||
|
||||
if not user_role:
|
||||
return False
|
||||
|
||||
# Role-based permission matrix
|
||||
permissions = {
|
||||
"owner": ["*"], # Owners can do everything
|
||||
"admin": ["read", "write", "delete", "manage"],
|
||||
"manager": ["read", "write"],
|
||||
"user": ["read"]
|
||||
}
|
||||
|
||||
allowed_actions = permissions.get(user_role, [])
|
||||
has_permission = "*" in allowed_actions or action in allowed_actions
|
||||
|
||||
logger.debug(f"Resource permission check",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
resource=resource,
|
||||
action=action,
|
||||
user_role=user_role,
|
||||
has_permission=has_permission)
|
||||
|
||||
return has_permission
|
||||
|
||||
async def get_user_tenants(self, user_id: str) -> list:
|
||||
"""
|
||||
Get all tenants a user has access to
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
list: List of tenant dictionaries
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/users/{user_id}"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
tenants = response.json()
|
||||
logger.debug(f"Retrieved user tenants",
|
||||
user_id=user_id,
|
||||
tenant_count=len(tenants))
|
||||
return tenants
|
||||
else:
|
||||
logger.warning(f"Failed to get user tenants: {response.status_code}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get user tenants: {e}")
|
||||
return []
|
||||
|
||||
# Global instance for easy import
|
||||
tenant_access_manager = TenantAccessManager()
|
||||
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FASTAPI DEPENDENCIES
|
||||
# ================================================================
|
||||
|
||||
async def verify_tenant_access_dep(
|
||||
tenant_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
) -> str:
|
||||
"""
|
||||
FastAPI dependency to verify tenant access and return tenant_id
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID from path parameter
|
||||
current_user: Current user from auth dependency
|
||||
|
||||
Returns:
|
||||
str: Validated tenant_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have access to tenant
|
||||
"""
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
|
||||
if not has_access:
|
||||
logger.warning(f"Access denied to tenant",
|
||||
user_id=current_user["user_id"],
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"User {current_user['user_id']} does not have access to tenant {tenant_id}"
|
||||
)
|
||||
|
||||
logger.debug(f"Tenant access verified",
|
||||
user_id=current_user["user_id"],
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return tenant_id
|
||||
|
||||
async def verify_tenant_permission_dep(
|
||||
tenant_id: str,
|
||||
resource: str,
|
||||
action: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
) -> str:
|
||||
"""
|
||||
FastAPI dependency to verify tenant access AND resource permission
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID from path parameter
|
||||
resource: Resource type being accessed
|
||||
action: Action being performed
|
||||
current_user: Current user from auth dependency
|
||||
|
||||
Returns:
|
||||
str: Validated tenant_id
|
||||
|
||||
Raises:
|
||||
HTTPException: If user doesn't have access or permission
|
||||
"""
|
||||
# First verify basic tenant access
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
|
||||
if not has_access:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Access denied to tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Then verify specific resource permission
|
||||
has_permission = await tenant_access_manager.verify_resource_permission(
|
||||
current_user["user_id"], tenant_id, resource, action
|
||||
)
|
||||
if not has_permission:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Insufficient permissions for {action} on {resource}"
|
||||
)
|
||||
|
||||
logger.debug(f"Tenant access and permission verified",
|
||||
user_id=current_user["user_id"],
|
||||
tenant_id=tenant_id,
|
||||
resource=resource,
|
||||
action=action)
|
||||
|
||||
return tenant_id
|
||||
|
||||
# ================================================================
|
||||
# UTILITY FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
def extract_tenant_id_from_path(path: str) -> Optional[str]:
|
||||
"""
|
||||
More robust tenant ID extraction using regex pattern matching
|
||||
Only matches actual tenant-scoped paths with UUID format
|
||||
"""
|
||||
# Pattern for tenant-scoped paths: /api/v1/tenants/{uuid}/...
|
||||
tenant_pattern = r'/api/v1/tenants/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})/.*'
|
||||
|
||||
match = re.match(tenant_pattern, path, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
def is_tenant_scoped_path(path: str) -> bool:
|
||||
"""
|
||||
Check if path is tenant-scoped (contains /tenants/{tenant_id}/)
|
||||
|
||||
Args:
|
||||
path: URL path
|
||||
|
||||
Returns:
|
||||
bool: True if path is tenant-scoped
|
||||
"""
|
||||
return extract_tenant_id_from_path(path) is not None
|
||||
|
||||
# ================================================================
|
||||
# EXPORTS
|
||||
# ================================================================
|
||||
|
||||
__all__ = [
|
||||
# Classes
|
||||
"TenantAccessManager",
|
||||
"tenant_access_manager",
|
||||
|
||||
# Dependencies
|
||||
"verify_tenant_access_dep",
|
||||
"verify_tenant_permission_dep",
|
||||
|
||||
# Utilities
|
||||
"extract_tenant_id_from_path",
|
||||
"is_tenant_scoped_path"
|
||||
]
|
||||
Reference in New Issue
Block a user