Add role-based filtering and imporve code
This commit is contained in:
388
shared/security/rate_limiter.py
Normal file
388
shared/security/rate_limiter.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Rate limiting and quota management system for subscription-based features
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import structlog
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class QuotaType(str, Enum):
|
||||
"""Types of quotas"""
|
||||
FORECAST_GENERATION = "forecast_generation"
|
||||
TRAINING_JOBS = "training_jobs"
|
||||
BULK_IMPORTS = "bulk_imports"
|
||||
POS_SYNC = "pos_sync"
|
||||
API_CALLS = "api_calls"
|
||||
DEMO_SESSIONS = "demo_sessions"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Redis-based rate limiter for subscription tier quotas
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
"""
|
||||
Initialize rate limiter
|
||||
|
||||
Args:
|
||||
redis_client: Redis client for storing quota counters
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.logger = logger
|
||||
|
||||
def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
|
||||
"""Generate Redis key for quota tracking"""
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}"
|
||||
|
||||
def _get_dataset_size_key(self, tenant_id: str) -> str:
|
||||
"""Generate Redis key for dataset size tracking"""
|
||||
return f"dataset_size:{tenant_id}"
|
||||
|
||||
async def check_and_increment_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_type: str,
|
||||
limit: Optional[int],
|
||||
period: int = 86400 # 24 hours in seconds
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if quota allows action and increment counter
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to check
|
||||
limit: Maximum allowed count (None = unlimited)
|
||||
period: Time period in seconds (default: 24 hours)
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- allowed: bool
|
||||
- current: int (current count)
|
||||
- limit: Optional[int]
|
||||
- reset_at: datetime (when quota resets)
|
||||
|
||||
Raises:
|
||||
HTTPException: If quota is exceeded
|
||||
"""
|
||||
if limit is None:
|
||||
# Unlimited quota
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": 0,
|
||||
"limit": None,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
|
||||
try:
|
||||
# Get current count
|
||||
current = await self.redis.get(key)
|
||||
current_count = int(current) if current else 0
|
||||
|
||||
# Check if limit exceeded
|
||||
if current_count >= limit:
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
|
||||
|
||||
self.logger.warning(
|
||||
"quota_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type,
|
||||
current=current_count,
|
||||
limit=limit,
|
||||
reset_at=reset_at.isoformat()
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"error": "quota_exceeded",
|
||||
"message": f"Daily quota exceeded for {quota_type}",
|
||||
"current": current_count,
|
||||
"limit": limit,
|
||||
"reset_at": reset_at.isoformat(),
|
||||
"quota_type": quota_type
|
||||
}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, period)
|
||||
await pipe.execute()
|
||||
|
||||
new_count = current_count + 1
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
|
||||
|
||||
self.logger.info(
|
||||
"quota_incremented",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type,
|
||||
current=new_count,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": new_count,
|
||||
"limit": limit,
|
||||
"reset_at": reset_at
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"quota_check_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
# Fail open - allow the operation
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": 0,
|
||||
"limit": limit,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
async def get_current_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current quota usage without incrementing
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to check
|
||||
|
||||
Returns:
|
||||
Dict with current usage information
|
||||
"""
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
|
||||
try:
|
||||
current = await self.redis.get(key)
|
||||
current_count = int(current) if current else 0
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
|
||||
return {
|
||||
"current": current_count,
|
||||
"reset_at": reset_at
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"usage_check_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
return {
|
||||
"current": 0,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
async def reset_quota(self, tenant_id: str, quota_type: str):
|
||||
"""
|
||||
Reset quota for a tenant (admin function)
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to reset
|
||||
"""
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
try:
|
||||
await self.redis.delete(key)
|
||||
self.logger.info(
|
||||
"quota_reset",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"quota_reset_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
|
||||
async def validate_dataset_size(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_rows: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate dataset size against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
dataset_rows: Number of rows in dataset
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If dataset size exceeds limit
|
||||
"""
|
||||
# Dataset size limits per tier
|
||||
dataset_limits = {
|
||||
'starter': 1000,
|
||||
'professional': 10000,
|
||||
'enterprise': None # Unlimited
|
||||
}
|
||||
|
||||
limit = dataset_limits.get(subscription_tier.lower())
|
||||
|
||||
if limit is not None and dataset_rows > limit:
|
||||
self.logger.warning(
|
||||
"dataset_size_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
dataset_rows=dataset_rows,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "dataset_size_limit_exceeded",
|
||||
"message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
|
||||
"current_size": dataset_rows,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"dataset_size_validated",
|
||||
tenant_id=tenant_id,
|
||||
dataset_rows=dataset_rows,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
async def validate_forecast_horizon(
|
||||
self,
|
||||
tenant_id: str,
|
||||
horizon_days: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate forecast horizon against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
horizon_days: Number of days to forecast
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If horizon exceeds limit
|
||||
"""
|
||||
# Forecast horizon limits per tier
|
||||
horizon_limits = {
|
||||
'starter': 7,
|
||||
'professional': 90,
|
||||
'enterprise': 365 # Practically unlimited
|
||||
}
|
||||
|
||||
limit = horizon_limits.get(subscription_tier.lower(), 7)
|
||||
|
||||
if horizon_days > limit:
|
||||
self.logger.warning(
|
||||
"forecast_horizon_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
horizon_days=horizon_days,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "forecast_horizon_limit_exceeded",
|
||||
"message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
|
||||
"requested_horizon": horizon_days,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"forecast_horizon_validated",
|
||||
tenant_id=tenant_id,
|
||||
horizon_days=horizon_days,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
async def validate_historical_data_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
days_back: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate historical data access against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
days_back: Number of days of historical data requested
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If historical data access exceeds limit
|
||||
"""
|
||||
# Historical data limits per tier
|
||||
history_limits = {
|
||||
'starter': 7,
|
||||
'professional': 90,
|
||||
'enterprise': None # Unlimited
|
||||
}
|
||||
|
||||
limit = history_limits.get(subscription_tier.lower(), 7)
|
||||
|
||||
if limit is not None and days_back > limit:
|
||||
self.logger.warning(
|
||||
"historical_data_limit_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
days_back=days_back,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "historical_data_limit_exceeded",
|
||||
"message": f"Historical data limited to {limit} days for {subscription_tier} tier",
|
||||
"requested_days": days_back,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"historical_data_access_validated",
|
||||
tenant_id=tenant_id,
|
||||
days_back=days_back,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
|
||||
def create_rate_limiter(redis_client) -> RateLimiter:
|
||||
"""Factory function to create rate limiter"""
|
||||
return RateLimiter(redis_client)
|
||||
Reference in New Issue
Block a user