479 lines
17 KiB
Python
Executable File
479 lines
17 KiB
Python
Executable File
"""
|
|
Subscription Tier and Role-Based Access Control Decorators
|
|
Provides unified access control across all microservices
|
|
"""
|
|
|
|
from enum import Enum
|
|
from functools import wraps
|
|
from typing import List, Callable, Dict, Any, Optional
|
|
from fastapi import HTTPException, status, Request, Depends
|
|
import structlog
|
|
|
|
from shared.auth.decorators import get_current_user_dep
|
|
|
|
# Module-level structlog logger used by every access-control helper and
# decorator in this module for allow/deny diagnostics.
logger = structlog.get_logger()
|
|
|
|
|
|
class SubscriptionTier(Enum):
    """
    Subscription plans a tenant can hold, declared from lowest to highest.

    Member values are the lowercase plan identifiers. They are intended to
    mirror the plans defined by the tenant service.
    """

    # Entry-level plan.
    STARTER = "starter"
    # Mid-level plan.
    PROFESSIONAL = "professional"
    # Top-level plan.
    ENTERPRISE = "enterprise"
|
|
|
|
|
|
class UserRole(Enum):
    """
    Tenant member roles, declared from least to most privileged.

    Member values are the lowercase role identifiers. They are intended to
    mirror the role definitions of the tenant member model.
    """

    # Read-only access.
    VIEWER = "viewer"
    # Regular member.
    MEMBER = "member"
    # Administrative access.
    ADMIN = "admin"
    # Tenant owner.
    OWNER = "owner"
|
|
|
|
|
|
# Rank of each subscription tier: a larger rank means a more capable plan.
# Ranks are assigned 1, 2, 3 following the order of the tuple below.
TIER_HIERARCHY = {
    tier: rank
    for rank, tier in enumerate(
        (
            SubscriptionTier.STARTER,
            SubscriptionTier.PROFESSIONAL,
            SubscriptionTier.ENTERPRISE,
        ),
        start=1,
    )
}

# Rank of each user role: a larger rank carries more permissions.
# Ranks are assigned 1, 2, 3, 4 following the order of the tuple below.
ROLE_HIERARCHY = {
    role: rank
    for rank, role in enumerate(
        (
            UserRole.VIEWER,
            UserRole.MEMBER,
            UserRole.ADMIN,
            UserRole.OWNER,
        ),
        start=1,
    )
}
|
|
|
|
|
|
def check_tier_access(user_tier: str, required_tiers: List[str]) -> bool:
    """
    Check if user's subscription tier meets the requirement.

    Tiers are hierarchical (ranked by TIER_HIERARCHY). Access is granted when
    the user's tier ranks at or above the *lowest* recognised tier in
    ``required_tiers``. Comparison is case-insensitive.

    Unrecognised entries in ``required_tiers`` are logged and ignored, so a
    single typo in an endpoint's allow-list does not lock out every tier.
    If no entry is recognised (or the list is empty/None), access is denied.

    Args:
        user_tier: Current user's subscription tier, e.g. "professional".
        required_tiers: List of allowed tiers.

    Returns:
        bool: True if access is allowed. False when the requirement is not
        met, when ``user_tier`` is unknown or not a string, or when no
        required tier is recognised.
    """
    try:
        user_tier_level = TIER_HIERARCHY.get(SubscriptionTier(user_tier.lower()), 0)
    except (ValueError, KeyError, AttributeError) as e:
        # Unknown tier name or non-string value (e.g. None): fail closed.
        logger.warning("Invalid tier comparison", user_tier=user_tier, required=required_tiers, error=str(e))
        return False

    required_levels = []
    for tier in required_tiers or ():
        try:
            required_levels.append(TIER_HIERARCHY.get(SubscriptionTier(tier.lower()), 999))
        except (ValueError, KeyError, AttributeError) as e:
            # Skip this entry only; the remaining recognised tiers still apply.
            logger.warning("Ignoring unknown required tier", tier=tier, required=required_tiers, error=str(e))

    if not required_levels:
        # Nothing to compare against: deny rather than grant by default.
        logger.warning(
            "Invalid tier comparison",
            user_tier=user_tier,
            required=required_tiers,
            error="no recognized required tiers",
        )
        return False

    return user_tier_level >= min(required_levels)
|
|
|
|
|
|
def check_role_access(user_role: str, required_roles: List[str]) -> bool:
    """
    Check if user's role meets the requirement.

    Roles are hierarchical (ranked by ROLE_HIERARCHY). Access is granted when
    the user's role ranks at or above the *lowest* recognised role in
    ``required_roles``. Comparison is case-insensitive.

    Unrecognised entries in ``required_roles`` (for example 'manager', which
    is not a UserRole) are logged and ignored, so that they do not deny
    everyone. If no entry is recognised (or the list is empty/None), access
    is denied.

    Args:
        user_role: Current user's role, e.g. "admin".
        required_roles: List of allowed roles.

    Returns:
        bool: True if access is allowed. False when the requirement is not
        met, when ``user_role`` is unknown or not a string, or when no
        required role is recognised.
    """
    try:
        user_role_level = ROLE_HIERARCHY.get(UserRole(user_role.lower()), 0)
    except (ValueError, KeyError, AttributeError) as e:
        # Unknown role name or non-string value (e.g. None): fail closed.
        logger.warning("Invalid role comparison", user_role=user_role, required=required_roles, error=str(e))
        return False

    required_levels = []
    for role in required_roles or ():
        try:
            required_levels.append(ROLE_HIERARCHY.get(UserRole(role.lower()), 999))
        except (ValueError, KeyError, AttributeError) as e:
            # Skip this entry only; the remaining recognised roles still apply.
            logger.warning("Ignoring unknown required role", role=role, required=required_roles, error=str(e))

    if not required_levels:
        # Nothing to compare against: deny rather than grant by default.
        logger.warning(
            "Invalid role comparison",
            user_role=user_role,
            required=required_roles,
            error="no recognized required roles",
        )
        return False

    return user_role_level >= min(required_levels)
|
|
|
|
|
|
def require_subscription_tier(allowed_tiers: List[str]):
    """
    Build a decorator that gates an async endpoint on the tenant's plan.

    The wrapped endpoint must receive the authenticated user dict. It is read
    from the ``current_user`` keyword argument (normally injected by
    ``get_current_user_dep``); failing that, the first positional dict
    containing ``user_id`` is used. The plan is read from the dict's
    ``subscription_tier`` entry and compared with ``check_tier_access``.

    Usage:
        @router.get("/analytics/advanced")
        @require_subscription_tier(['professional', 'enterprise'])
        async def get_advanced_analytics(...):
            ...

    Args:
        allowed_tiers: Subscription tiers permitted to call the endpoint.

    Raises (when the wrapped endpoint is called):
        HTTPException 401: no user context could be located.
        HTTPException 403: the user context has no subscription tier.
        HTTPException 402: the tier does not satisfy ``allowed_tiers``.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            user = kwargs.get('current_user')
            if not user:
                user = next(
                    (candidate for candidate in args
                     if isinstance(candidate, dict) and 'user_id' in candidate),
                    None,
                )

            if not user:
                logger.error("Current user not found in request context")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required"
                )

            # The gateway is expected to place the tenant's plan in the user context.
            plan = user.get('subscription_tier')
            if not plan:
                logger.warning("Subscription tier not found in user context", user_id=user.get('user_id'))
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Subscription information unavailable"
                )

            if check_tier_access(plan, allowed_tiers):
                logger.debug("Subscription tier check passed", tier=plan, required=allowed_tiers)
                return await func(*args, **kwargs)

            logger.warning(
                "Subscription tier access denied",
                user_tier=plan,
                required_tiers=allowed_tiers,
                user_id=user.get('user_id')
            )
            upgrade_detail = {
                "error": "subscription_tier_insufficient",
                "message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
                "current_plan": plan,
                "required_plans": allowed_tiers,
                "upgrade_url": "/app/settings/profile"
            }
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=upgrade_detail
            )

        return wrapper

    return decorator
|
|
|
|
|
|
def require_user_role(allowed_roles: List[str]):
    """
    Build a decorator that restricts an async endpoint by the caller's role.

    The user dict is read from the ``current_user`` keyword argument; failing
    that, the first positional dict containing ``user_id`` is used. Its
    ``role`` entry defaults to ``'user'`` when absent. That default is not a
    UserRole value, so a missing role fails ``check_role_access``.

    Usage:
        @router.delete("/ingredients/{id}")
        @require_user_role(['admin', 'owner'])
        async def delete_ingredient(...):
            ...

    Args:
        allowed_roles: Roles permitted to call the endpoint. These should be
            UserRole values ('viewer', 'member', 'admin', 'owner').

    Raises (when the wrapped endpoint is called):
        HTTPException 401: no user context could be located.
        HTTPException 403: the role does not satisfy ``allowed_roles``.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            user = kwargs.get('current_user')
            if not user:
                user = next(
                    (candidate for candidate in args
                     if isinstance(candidate, dict) and 'user_id' in candidate),
                    None,
                )

            if not user:
                logger.error("Current user not found in request context")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required"
                )

            role = user.get('role', 'user')

            if check_role_access(role, allowed_roles):
                logger.debug("Role check passed", role=role, required=allowed_roles)
                return await func(*args, **kwargs)

            logger.warning(
                "Role-based access denied",
                user_role=role,
                required_roles=allowed_roles,
                user_id=user.get('user_id')
            )
            denial_detail = {
                "error": "insufficient_permissions",
                "message": f"This action requires {' or '.join(allowed_roles)} role",
                "current_role": role,
                "required_roles": allowed_roles
            }
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=denial_detail
            )

        return wrapper

    return decorator
|
|
|
|
|
|
def require_tier_and_role(
    allowed_tiers: List[str],
    allowed_roles: List[str]
):
    """
    Combined decorator for both tier and role enforcement.

    The user dict is read from the ``current_user`` keyword argument; failing
    that, the first positional dict containing ``user_id`` is used. The
    subscription tier is checked first, then the role.

    A missing ``subscription_tier`` is rejected with 403, as in
    ``require_subscription_tier``. It is no longer silently treated as
    passing the tier check.

    Usage:
        @router.post("/analytics/custom-report")
        @require_tier_and_role(['professional', 'enterprise'], ['admin', 'owner'])
        async def create_custom_report(...):
            ...

    Args:
        allowed_tiers: List of subscription tiers allowed.
        allowed_roles: List of user roles allowed. These should be UserRole
            values ('viewer', 'member', 'admin', 'owner').

    Raises (when the wrapped endpoint is called):
        HTTPException 401: no user context could be located.
        HTTPException 403: tier missing from the context, or role insufficient.
        HTTPException 402: the tier does not satisfy ``allowed_tiers``.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            current_user = kwargs.get('current_user')
            if not current_user:
                current_user = next(
                    (arg for arg in args if isinstance(arg, dict) and 'user_id' in arg),
                    None,
                )

            if not current_user:
                logger.error("Current user not found in request context")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required"
                )

            user_id = current_user.get('user_id')

            # Tier check: fail closed when the gateway supplied no tier.
            subscription_tier = current_user.get('subscription_tier')
            if not subscription_tier:
                logger.warning("Subscription tier not found in user context", user_id=user_id)
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Subscription information unavailable"
                )

            if not check_tier_access(subscription_tier, allowed_tiers):
                logger.warning(
                    "Combined access control: tier check failed",
                    user_tier=subscription_tier,
                    required_tiers=allowed_tiers,
                    user_id=user_id
                )
                raise HTTPException(
                    status_code=status.HTTP_402_PAYMENT_REQUIRED,
                    detail={
                        "error": "subscription_tier_insufficient",
                        "message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
                        "current_plan": subscription_tier,
                        "required_plans": allowed_tiers,
                        "upgrade_url": "/app/settings/profile"
                    }
                )

            # Role check.
            # NOTE(review): require_user_role defaults a missing role to 'user'
            # (always denied), whereas this decorator defaults it to 'member';
            # confirm which default is intended.
            user_role = current_user.get('role', 'member')
            if not check_role_access(user_role, allowed_roles):
                logger.warning(
                    "Combined access control: role check failed",
                    user_role=user_role,
                    required_roles=allowed_roles,
                    user_id=user_id
                )
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail={
                        "error": "insufficient_permissions",
                        "message": f"This action requires {' or '.join(allowed_roles)} role",
                        "current_role": user_role,
                        "required_roles": allowed_roles
                    }
                )

            logger.debug(
                "Combined access control passed",
                tier=subscription_tier,
                role=user_role,
                required_tiers=allowed_tiers,
                required_roles=allowed_roles
            )
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
# Ready-made decorators for the most common access rules.

# Professional plan or above.
analytics_tier_required = require_subscription_tier(allowed_tiers=['professional', 'enterprise'])

# Enterprise plan only.
enterprise_tier_required = require_subscription_tier(allowed_tiers=['enterprise'])

# Admin role or above.
admin_role_required = require_user_role(allowed_roles=['admin', 'owner'])

# Owner role only.
owner_role_required = require_user_role(allowed_roles=['owner'])
|
|
|
|
|
|
def service_only_access(func: Callable) -> Callable:
    """
    Restrict an async endpoint to service-to-service callers.

    The caller context is read from the ``current_user`` keyword argument
    (normally injected by ``get_current_user_dep``). Failing that, the first
    positional dict containing ``user_id`` is used. The call is allowed when
    that context has ``type == 'service'`` or a truthy ``is_service`` flag.

    NOTE: this decorator does not verify *which* service is calling. Any
    context marked as a service is accepted; the service name (``service``,
    falling back to ``user_id``) is only logged.

    Usage:
        @router.delete("/tenant/{tenant_id}")
        @service_only_access
        async def delete_tenant_data(
            tenant_id: str,
            current_user: dict = Depends(get_current_user_dep),
            db = Depends(get_db)
        ):
            # Service-only logic here

    Raises (when the wrapped endpoint is called):
        HTTPException 401: no caller context could be located.
        HTTPException 403: the caller context is not marked as a service.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        caller = kwargs.get('current_user')
        if not caller:
            caller = next(
                (candidate for candidate in args
                 if isinstance(candidate, dict) and 'user_id' in candidate),
                None,
            )

        if not caller:
            logger.error("Service-only access: current user not found in request context")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication required"
            )

        token_type = caller.get('type', '')
        service_flag = caller.get('is_service', False)

        if token_type == 'service' or service_flag:
            caller_name = caller.get('service', caller.get('user_id', 'unknown'))
            logger.info(
                "Service-only access granted",
                service=caller_name,
                endpoint=func.__name__
            )
            return await func(*args, **kwargs)

        logger.warning(
            "Service-only access denied: not a service token",
            user_id=caller.get('user_id'),
            user_type=token_type,
            is_service=service_flag
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="This endpoint is only accessible to internal services"
        )

    return wrapper
|
|
|
|
|
|
def _find_request_argument(args: tuple, kwargs: dict) -> Optional[Request]:
    """
    Locate the incoming request among an endpoint's call arguments.

    Lookup order:
    1. The ``request`` keyword argument, if truthy.
    2. The first positional or keyword argument that is a ``Request``.
    3. For backward compatibility, the first positional argument if it
       has a ``headers`` attribute.

    Returns None when nothing suitable is found.
    """
    explicit = kwargs.get('request')
    if explicit:
        return explicit
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, Request):
            return candidate
    if args and hasattr(args[0], 'headers'):
        return args[0]
    return None


def require_verified_subscription_tier(
    allowed_tiers: List[str],
    verify_in_database: bool = False
):
    """
    Subscription tier enforcement with optional database verification.

    The tier is read from the gateway-injected ``x-subscription-tier`` header,
    defaulting to "starter" when the header is absent. Unlike
    ``require_subscription_tier``, the tier must match ``allowed_tiers``
    *exactly* (case-insensitive). Higher tiers are not implied, so list
    every tier that should have access.

    Args:
        allowed_tiers: List of allowed subscription tiers.
        verify_in_database: If True (for critical operations), also look up
            the tenant's plan using the tenant id from the ``x-tenant-id``
            header, and require that plan to be in ``allowed_tiers``.
            - A header/lookup mismatch is logged as an error.
            - A missing tenant header skips the lookup with a warning.
            - An empty lookup result is treated as "starter".

    Raises (when the wrapped endpoint is called):
        HTTPException 500: no request object found in the endpoint arguments.
        HTTPException 402: header tier, or verified tier, not allowed.
    """
    def decorator(func):
        # Normalised once per decorated endpoint.
        allowed = {tier.lower() for tier in allowed_tiers}

        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = _find_request_argument(args, kwargs)
            if request is None:
                logger.error(
                    "Request object not found for subscription verification",
                    endpoint=getattr(func, '__name__', repr(func))
                )
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Request context unavailable"
                )

            # Get tier from gateway-injected header (from verified JWT)
            header_tier = request.headers.get("x-subscription-tier", "starter").lower()

            if header_tier not in allowed:
                raise HTTPException(
                    status_code=status.HTTP_402_PAYMENT_REQUIRED,
                    detail={
                        "error": "subscription_required",
                        "message": f"This feature requires {', '.join(allowed_tiers)} subscription",
                        "current_tier": header_tier,
                        "required_tiers": allowed_tiers
                    }
                )

            if verify_in_database:
                tenant_id = request.headers.get("x-tenant-id")
                user_id = request.headers.get("x-user-id")
                if not tenant_id:
                    # Verification was requested but cannot be performed;
                    # surface it instead of skipping silently.
                    logger.warning(
                        "Subscription database verification skipped: missing tenant id",
                        header_tier=header_tier,
                        user_id=user_id
                    )
                else:
                    raw_db_tier = await _verify_subscription_in_database(tenant_id)
                    # Guard against an empty result so .lower() cannot fail.
                    db_tier = (raw_db_tier or "starter").lower()

                    if db_tier != header_tier:
                        logger.error(
                            "Subscription tier mismatch detected!",
                            header_tier=header_tier,
                            db_tier=raw_db_tier,
                            tenant_id=tenant_id,
                            user_id=user_id
                        )

                    # The looked-up tier is authoritative.
                    if db_tier not in allowed:
                        raise HTTPException(
                            status_code=status.HTTP_402_PAYMENT_REQUIRED,
                            detail={
                                "error": "subscription_verification_failed",
                                "message": "Subscription tier verification failed"
                            }
                        )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
async def _verify_subscription_in_database(tenant_id: str) -> str:
    """
    Look up a tenant's subscription plan through SubscriptionClient.

    Used for critical operations as defense-in-depth against the
    gateway-supplied tier header.

    Args:
        tenant_id: Tenant whose subscription plan to fetch.

    Returns:
        str: The subscription's "plan" value. Returns "starter" when the
        client returns no subscription or the plan is missing/empty, so
        callers always get a string.
    """
    # Local import, kept as in the original (presumably to avoid an import
    # cycle or import-time cost; TODO confirm).
    from shared.clients.subscription_client import SubscriptionClient

    client = SubscriptionClient()
    subscription = await client.get_subscription(tenant_id)
    if not subscription:
        logger.warning("No subscription returned for tenant; assuming starter", tenant_id=tenant_id)
        return "starter"

    return subscription.get("plan") or "starter"
|