# Source: bakery-ia/shared/security/rate_limiter.py
# (repository file listing: 389 lines, 12 KiB, Python, executable file)
"""
Rate limiting and quota management system for subscription-based features
"""
import time
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from enum import Enum
import structlog
from fastapi import HTTPException, status
logger = structlog.get_logger()
class QuotaType(str, Enum):
    """
    Names of the per-tenant quotas tracked by the rate limiter.

    Members subclass ``str``, so each one compares equal to its plain string
    value and can be used directly wherever a quota name string is expected.
    """

    # Forecast runs
    FORECAST_GENERATION = "forecast_generation"
    # Model training jobs
    TRAINING_JOBS = "training_jobs"
    # Bulk data imports
    BULK_IMPORTS = "bulk_imports"
    # Point-of-sale synchronisations
    POS_SYNC = "pos_sync"
    # Generic API calls
    API_CALLS = "api_calls"
    # Demo sessions
    DEMO_SESSIONS = "demo_sessions"
class RateLimiter:
    """
    Redis-based rate limiter for subscription tier quotas.

    Quota counters live under keys that embed the current UTC date (see
    ``_get_quota_key``). A counter therefore starts from zero again at UTC
    midnight, and its TTL expires it earlier if the TTL is shorter.
    """

    def __init__(self, redis_client):
        """
        Initialize rate limiter

        Args:
            redis_client: Async Redis client (its commands are awaited) used
                for storing quota counters
        """
        self.redis = redis_client
        self.logger = logger

    def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
        """Generate Redis key for quota tracking (one key per UTC date)"""
        date_str = datetime.utcnow().strftime("%Y-%m-%d")
        return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}"

    def _get_dataset_size_key(self, tenant_id: str) -> str:
        """Generate Redis key for dataset size tracking"""
        return f"dataset_size:{tenant_id}"

    @staticmethod
    def _compute_reset_at(ttl: int, fallback_seconds: int) -> datetime:
        """
        Return the (naive UTC) time at which a quota counter next resets.

        The counter vanishes when its TTL runs out or when the UTC date
        embedded in its key changes, whichever happens first.

        Args:
            ttl: Remaining TTL in seconds as reported by Redis (<= 0 means the
                key is missing or has no expiry)
            fallback_seconds: Lifetime to assume when ``ttl`` is not positive
        """
        now = datetime.utcnow()
        expires_at = now + timedelta(seconds=ttl if ttl > 0 else fallback_seconds)
        next_midnight = (now + timedelta(days=1)).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        return min(expires_at, next_midnight)

    async def _increment_counter(self, key: str, period: int):
        """
        Increment ``key`` and make sure it carries an expiry.

        INCR is atomic on the Redis side, so the returned count is the value
        produced by this call. Two concurrent requests can no longer both read
        the same count and slip past the limit. The expiry is only set when
        the key has none. Repeated calls therefore do not keep pushing the
        reset time forward.

        Returns:
            Tuple ``(new_count, ttl_seconds)``
        """
        pipe = self.redis.pipeline()
        pipe.incr(key)
        pipe.ttl(key)
        new_count, ttl = await pipe.execute()
        if ttl < 0:
            # -1: no expiry yet (first increment, or an earlier EXPIRE was lost)
            await self.redis.expire(key, period)
            ttl = period
        return int(new_count), int(ttl)

    async def _rollback_increment(self, key: str, tenant_id: str, quota_type: str):
        """Best-effort undo of a rejected increment so refused requests do not consume quota"""
        try:
            await self.redis.decr(key)
        except Exception as e:
            self.logger.error(
                "quota_rollback_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )

    async def check_and_increment_quota(
        self,
        tenant_id: str,
        quota_type: str,
        limit: Optional[int],
        period: int = 86400  # 24 hours in seconds
    ) -> Dict[str, Any]:
        """
        Check if quota allows action and increment counter

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check
            limit: Maximum allowed count (None = unlimited)
            period: Counter lifetime in seconds (default: 24 hours). The key
                also rolls over at UTC midnight regardless of this value.

        Returns:
            Dict with:
                - allowed: bool
                - current: int (current count)
                - limit: Optional[int]
                - reset_at: datetime (when quota resets)

        Raises:
            HTTPException: 429 if quota is exceeded. Redis failures do not
                raise; the operation is allowed (fail open).
        """
        if limit is None:
            # Unlimited quota
            return {
                "allowed": True,
                "current": 0,
                "limit": None,
                "reset_at": None
            }

        key = self._get_quota_key(tenant_id, quota_type)

        try:
            # Increment first, then compare: avoids the GET-then-INCR race
            new_count, ttl = await self._increment_counter(key, period)
        except Exception as e:
            self.logger.error(
                "quota_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )
            # Fail open - allow the operation
            return {
                "allowed": True,
                "current": 0,
                "limit": limit,
                "reset_at": None
            }

        reset_at = self._compute_reset_at(ttl, period)

        if new_count > limit:
            # The request is refused, so give back the slot we just took
            await self._rollback_increment(key, tenant_id, quota_type)
            current_count = new_count - 1

            self.logger.warning(
                "quota_exceeded",
                tenant_id=tenant_id,
                quota_type=quota_type,
                current=current_count,
                limit=limit,
                reset_at=reset_at.isoformat()
            )

            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "error": "quota_exceeded",
                    "message": f"Daily quota exceeded for {quota_type}",
                    "current": current_count,
                    "limit": limit,
                    "reset_at": reset_at.isoformat(),
                    "quota_type": quota_type
                }
            )

        self.logger.info(
            "quota_incremented",
            tenant_id=tenant_id,
            quota_type=quota_type,
            current=new_count,
            limit=limit
        )

        return {
            "allowed": True,
            "current": new_count,
            "limit": limit,
            "reset_at": reset_at
        }

    async def get_current_usage(
        self,
        tenant_id: str,
        quota_type: str
    ) -> Dict[str, Any]:
        """
        Get current quota usage without incrementing

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check

        Returns:
            Dict with ``current`` (int) and ``reset_at`` (datetime, or None
            when no counter exists yet or Redis fails)
        """
        key = self._get_quota_key(tenant_id, quota_type)

        try:
            current = await self.redis.get(key)
            current_count = int(current) if current else 0
            ttl = await self.redis.ttl(key)
        except Exception as e:
            self.logger.error(
                "usage_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )
            return {
                "current": 0,
                "reset_at": None
            }

        reset_at = self._compute_reset_at(ttl, 0) if ttl > 0 else None

        return {
            "current": current_count,
            "reset_at": reset_at
        }

    async def reset_quota(self, tenant_id: str, quota_type: str):
        """
        Reset quota for a tenant (admin function)

        Only today's (UTC) counter is deleted. Errors are logged, not raised.

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to reset
        """
        key = self._get_quota_key(tenant_id, quota_type)

        try:
            await self.redis.delete(key)
            self.logger.info(
                "quota_reset",
                tenant_id=tenant_id,
                quota_type=quota_type
            )
        except Exception as e:
            self.logger.error(
                "quota_reset_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )

    async def validate_dataset_size(
        self,
        tenant_id: str,
        dataset_rows: int,
        subscription_tier: str
    ):
        """
        Validate dataset size against subscription tier limits

        Args:
            tenant_id: Tenant ID
            dataset_rows: Number of rows in dataset
            subscription_tier: User's subscription tier (unknown tiers are
                treated as 'starter')

        Raises:
            HTTPException: 402 if dataset size exceeds limit
        """
        # Dataset size limits per tier
        dataset_limits = {
            'starter': 1000,
            'professional': 10000,
            'enterprise': None  # Unlimited
        }

        tier = subscription_tier.lower()
        # Unknown tiers get the starter limit, as in the other validators.
        # Previously they fell through to None, which meant "unlimited".
        limit = dataset_limits[tier] if tier in dataset_limits else dataset_limits['starter']

        if limit is not None and dataset_rows > limit:
            self.logger.warning(
                "dataset_size_exceeded",
                tenant_id=tenant_id,
                dataset_rows=dataset_rows,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "dataset_size_limit_exceeded",
                    "message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
                    "current_size": dataset_rows,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "dataset_size_validated",
            tenant_id=tenant_id,
            dataset_rows=dataset_rows,
            tier=subscription_tier
        )

    async def validate_forecast_horizon(
        self,
        tenant_id: str,
        horizon_days: int,
        subscription_tier: str
    ):
        """
        Validate forecast horizon against subscription tier limits

        Args:
            tenant_id: Tenant ID
            horizon_days: Number of days to forecast
            subscription_tier: User's subscription tier (unknown tiers get
                the 7-day limit)

        Raises:
            HTTPException: 402 if horizon exceeds limit
        """
        # Forecast horizon limits per tier
        horizon_limits = {
            'starter': 7,
            'professional': 90,
            'enterprise': 365  # Practically unlimited
        }

        limit = horizon_limits.get(subscription_tier.lower(), 7)

        if horizon_days > limit:
            self.logger.warning(
                "forecast_horizon_exceeded",
                tenant_id=tenant_id,
                horizon_days=horizon_days,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "forecast_horizon_limit_exceeded",
                    "message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
                    "requested_horizon": horizon_days,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "forecast_horizon_validated",
            tenant_id=tenant_id,
            horizon_days=horizon_days,
            tier=subscription_tier
        )

    async def validate_historical_data_access(
        self,
        tenant_id: str,
        days_back: int,
        subscription_tier: str
    ):
        """
        Validate historical data access against subscription tier limits

        Args:
            tenant_id: Tenant ID
            days_back: Number of days of historical data requested
            subscription_tier: User's subscription tier (unknown tiers get
                the 7-day limit)

        Raises:
            HTTPException: 402 if historical data access exceeds limit
        """
        # Historical data limits per tier
        history_limits = {
            'starter': 7,
            'professional': 90,
            'enterprise': None  # Unlimited
        }

        limit = history_limits.get(subscription_tier.lower(), 7)

        if limit is not None and days_back > limit:
            self.logger.warning(
                "historical_data_limit_exceeded",
                tenant_id=tenant_id,
                days_back=days_back,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "historical_data_limit_exceeded",
                    "message": f"Historical data limited to {limit} days for {subscription_tier} tier",
                    "requested_days": days_back,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "historical_data_access_validated",
            tenant_id=tenant_id,
            days_back=days_back,
            tier=subscription_tier
        )
def create_rate_limiter(redis_client) -> RateLimiter:
    """
    Factory: return a new RateLimiter that keeps its counters in redis_client.

    Args:
        redis_client: Async Redis client handed straight to RateLimiter
    """
    limiter = RateLimiter(redis_client)
    return limiter