REFACTOR ALL APIs
This commit is contained in:
338
shared/auth/access_control.py
Normal file
338
shared/auth/access_control.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Subscription Tier and Role-Based Access Control Decorators
|
||||
Provides unified access control across all microservices
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import List, Callable, Dict, Any, Optional
|
||||
from fastapi import HTTPException, status, Request, Depends
|
||||
import structlog
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionTier(Enum):
|
||||
"""
|
||||
Subscription tier hierarchy
|
||||
Matches project-wide subscription plans in tenant service
|
||||
"""
|
||||
STARTER = "starter"
|
||||
PROFESSIONAL = "professional"
|
||||
ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
class UserRole(Enum):
|
||||
"""
|
||||
User role hierarchy
|
||||
Matches project-wide role definitions in tenant member model
|
||||
"""
|
||||
VIEWER = "viewer"
|
||||
MEMBER = "member"
|
||||
ADMIN = "admin"
|
||||
OWNER = "owner"
|
||||
|
||||
|
||||
# Tier hierarchy for comparison (higher number = higher tier)
|
||||
TIER_HIERARCHY = {
|
||||
SubscriptionTier.STARTER: 1,
|
||||
SubscriptionTier.PROFESSIONAL: 2,
|
||||
SubscriptionTier.ENTERPRISE: 3,
|
||||
}
|
||||
|
||||
# Role hierarchy for comparison (higher number = more permissions)
|
||||
ROLE_HIERARCHY = {
|
||||
UserRole.VIEWER: 1,
|
||||
UserRole.MEMBER: 2,
|
||||
UserRole.ADMIN: 3,
|
||||
UserRole.OWNER: 4,
|
||||
}
|
||||
|
||||
|
||||
def check_tier_access(user_tier: str, required_tiers: List[str]) -> bool:
|
||||
"""
|
||||
Check if user's subscription tier meets the requirement
|
||||
|
||||
Args:
|
||||
user_tier: Current user's subscription tier
|
||||
required_tiers: List of allowed tiers
|
||||
|
||||
Returns:
|
||||
bool: True if access is allowed
|
||||
"""
|
||||
try:
|
||||
user_tier_enum = SubscriptionTier(user_tier.lower())
|
||||
user_tier_level = TIER_HIERARCHY.get(user_tier_enum, 0)
|
||||
|
||||
# Get minimum required tier level
|
||||
min_required_level = min(
|
||||
TIER_HIERARCHY.get(SubscriptionTier(tier.lower()), 999)
|
||||
for tier in required_tiers
|
||||
)
|
||||
|
||||
return user_tier_level >= min_required_level
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.warning("Invalid tier comparison", user_tier=user_tier, required=required_tiers, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def check_role_access(user_role: str, required_roles: List[str]) -> bool:
|
||||
"""
|
||||
Check if user's role meets the requirement
|
||||
|
||||
Args:
|
||||
user_role: Current user's role
|
||||
required_roles: List of allowed roles
|
||||
|
||||
Returns:
|
||||
bool: True if access is allowed
|
||||
"""
|
||||
try:
|
||||
user_role_enum = UserRole(user_role.lower())
|
||||
user_role_level = ROLE_HIERARCHY.get(user_role_enum, 0)
|
||||
|
||||
# Get minimum required role level
|
||||
min_required_level = min(
|
||||
ROLE_HIERARCHY.get(UserRole(role.lower()), 999)
|
||||
for role in required_roles
|
||||
)
|
||||
|
||||
return user_role_level >= min_required_level
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.warning("Invalid role comparison", user_role=user_role, required=required_roles, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def require_subscription_tier(allowed_tiers: List[str]):
|
||||
"""
|
||||
Decorator to enforce subscription tier access control
|
||||
|
||||
Usage:
|
||||
@router.get("/analytics/advanced")
|
||||
@require_subscription_tier(['professional', 'enterprise'])
|
||||
async def get_advanced_analytics(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_tiers: List of subscription tiers allowed to access this endpoint
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs (injected by get_current_user_dep)
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Get tenant's subscription tier from user context
|
||||
# The gateway should inject this information
|
||||
subscription_tier = current_user.get('subscription_tier')
|
||||
|
||||
if not subscription_tier:
|
||||
logger.warning("Subscription tier not found in user context", user_id=current_user.get('user_id'))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Subscription information unavailable"
|
||||
)
|
||||
|
||||
# Check tier access
|
||||
has_access = check_tier_access(subscription_tier, allowed_tiers)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
"Subscription tier access denied",
|
||||
user_tier=subscription_tier,
|
||||
required_tiers=allowed_tiers,
|
||||
user_id=current_user.get('user_id')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_tier_insufficient",
|
||||
"message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
|
||||
"current_plan": subscription_tier,
|
||||
"required_plans": allowed_tiers,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Subscription tier check passed", tier=subscription_tier, required=allowed_tiers)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_user_role(allowed_roles: List[str]):
|
||||
"""
|
||||
Decorator to enforce role-based access control
|
||||
|
||||
Usage:
|
||||
@router.delete("/ingredients/{id}")
|
||||
@require_user_role(['admin', 'manager'])
|
||||
async def delete_ingredient(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_roles: List of user roles allowed to access this endpoint
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Get user's role
|
||||
user_role = current_user.get('role', 'user')
|
||||
|
||||
# Check role access
|
||||
has_access = check_role_access(user_role, allowed_roles)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
"Role-based access denied",
|
||||
user_role=user_role,
|
||||
required_roles=allowed_roles,
|
||||
user_id=current_user.get('user_id')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "insufficient_permissions",
|
||||
"message": f"This action requires {' or '.join(allowed_roles)} role",
|
||||
"current_role": user_role,
|
||||
"required_roles": allowed_roles
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Role check passed", role=user_role, required=allowed_roles)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_tier_and_role(
|
||||
allowed_tiers: List[str],
|
||||
allowed_roles: List[str]
|
||||
):
|
||||
"""
|
||||
Combined decorator for both tier and role enforcement
|
||||
|
||||
Usage:
|
||||
@router.post("/analytics/custom-report")
|
||||
@require_tier_and_role(['professional', 'enterprise'], ['admin', 'manager'])
|
||||
async def create_custom_report(...):
|
||||
...
|
||||
|
||||
Args:
|
||||
allowed_tiers: List of subscription tiers allowed
|
||||
allowed_roles: List of user roles allowed
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get current user from kwargs
|
||||
current_user = kwargs.get('current_user')
|
||||
|
||||
if not current_user:
|
||||
# Try to find in args
|
||||
for arg in args:
|
||||
if isinstance(arg, dict) and 'user_id' in arg:
|
||||
current_user = arg
|
||||
break
|
||||
|
||||
if not current_user:
|
||||
logger.error("Current user not found in request context")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
# Check subscription tier
|
||||
subscription_tier = current_user.get('subscription_tier')
|
||||
if subscription_tier:
|
||||
tier_access = check_tier_access(subscription_tier, allowed_tiers)
|
||||
if not tier_access:
|
||||
logger.warning(
|
||||
"Combined access control: tier check failed",
|
||||
user_tier=subscription_tier,
|
||||
required_tiers=allowed_tiers
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "subscription_tier_insufficient",
|
||||
"message": f"This feature requires a {' or '.join(allowed_tiers)} subscription plan",
|
||||
"current_plan": subscription_tier,
|
||||
"required_plans": allowed_tiers,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
# Check user role
|
||||
user_role = current_user.get('role', 'member')
|
||||
role_access = check_role_access(user_role, allowed_roles)
|
||||
|
||||
if not role_access:
|
||||
logger.warning(
|
||||
"Combined access control: role check failed",
|
||||
user_role=user_role,
|
||||
required_roles=allowed_roles
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "insufficient_permissions",
|
||||
"message": f"This action requires {' or '.join(allowed_roles)} role",
|
||||
"current_role": user_role,
|
||||
"required_roles": allowed_roles
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Combined access control passed",
|
||||
tier=subscription_tier,
|
||||
role=user_role,
|
||||
required_tiers=allowed_tiers,
|
||||
required_roles=allowed_roles
|
||||
)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# Convenience decorators for common patterns
|
||||
analytics_tier_required = require_subscription_tier(['professional', 'enterprise'])
|
||||
enterprise_tier_required = require_subscription_tier(['enterprise'])
|
||||
admin_role_required = require_user_role(['admin', 'owner'])
|
||||
owner_role_required = require_user_role(['owner'])
|
||||
Reference in New Issue
Block a user