389 lines
12 KiB
Python
Executable File
389 lines
12 KiB
Python
Executable File
"""
|
|
Rate limiting and quota management system for subscription-based features
|
|
"""
|
|
|
|
import time
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime, timedelta
|
|
from enum import Enum
|
|
import structlog
|
|
from fastapi import HTTPException, status
|
|
|
|
# Module-level structlog logger; each RateLimiter keeps a reference to it as ``self.logger``.
logger = structlog.get_logger()
|
|
|
|
|
|
class QuotaType(str, Enum):
    """
    Types of quotas.

    Members are real ``str`` instances (``str`` mixin), so they compare equal
    to their values. ``__str__`` is overridden to return the value so that
    ``str()`` and f-string formatting yield e.g. ``"api_calls"`` on every
    Python version; without it, Python 3.11+ formats a ``(str, Enum)`` member
    as ``"QuotaType.API_CALLS"``, which would change any key or message built
    from a member with an f-string.
    """

    FORECAST_GENERATION = "forecast_generation"
    TRAINING_JOBS = "training_jobs"
    BULK_IMPORTS = "bulk_imports"
    POS_SYNC = "pos_sync"
    API_CALLS = "api_calls"
    DEMO_SESSIONS = "demo_sessions"

    def __str__(self) -> str:
        # Render as the raw value (same behavior as Python 3.11's StrEnum).
        return self.value
|
|
|
|
|
|
class RateLimiter:
    """
    Redis-based rate limiter for subscription tier quotas.

    Quota counters are stored under keys that embed the current UTC date (see
    ``_get_quota_key``), so each counter starts over at UTC midnight. The TTL
    placed on a key only cleans it up, or resets it earlier when a ``period``
    shorter than the rest of the day is used.

    The tier validators treat an unrecognised (or missing) tier name like the
    ``starter`` tier, so a typo or a new tier name never silently grants
    unlimited access.
    """

    # Maximum dataset rows per tier (None = unlimited).
    _DATASET_LIMITS = {
        'starter': 1000,
        'professional': 10000,
        'enterprise': None,  # Unlimited
    }

    # Maximum forecast horizon in days per tier.
    _HORIZON_LIMITS = {
        'starter': 7,
        'professional': 90,
        'enterprise': 365,  # Practically unlimited
    }

    # Maximum historical look-back in days per tier (None = unlimited).
    _HISTORY_LIMITS = {
        'starter': 7,
        'professional': 90,
        'enterprise': None,  # Unlimited
    }

    # Tier whose limits apply when the caller's tier is not in a limits table.
    _FALLBACK_TIER = 'starter'

    def __init__(self, redis_client):
        """
        Initialize rate limiter

        Args:
            redis_client: Async Redis client (its commands are awaited) used
                for storing quota counters
        """
        self.redis = redis_client
        self.logger = logger

    @staticmethod
    def _quota_name(quota_type) -> str:
        """Return the plain string name of a quota type (accepts str or a str-valued Enum)."""
        # Enum members expose ``.value``; plain strings do not. Using the value
        # keeps Redis keys identical whether callers pass "api_calls" or
        # QuotaType.API_CALLS, independent of Enum formatting rules.
        return str(getattr(quota_type, "value", quota_type))

    @classmethod
    def _tier_limit(cls, limits: Dict[str, Optional[int]], subscription_tier: Optional[str]) -> Optional[int]:
        """
        Look up a tier's limit, case-insensitively.

        Unknown or missing tiers get the ``_FALLBACK_TIER`` (starter) limit.
        A tier explicitly mapped to None is unlimited.
        """
        tier = (subscription_tier or "").lower()
        return limits.get(tier, limits[cls._FALLBACK_TIER])

    @staticmethod
    def _compute_reset_at(ttl: int, period: int = 86400) -> datetime:
        """
        Estimate when a quota counter will start over (naive UTC datetime).

        The counter resets at whichever comes first: expiry of the key's TTL,
        or the next UTC midnight, when ``_get_quota_key`` switches to a new
        date and therefore a fresh key.

        Args:
            ttl: Remaining TTL in seconds as reported by Redis; values <= 0
                (no expiry / missing key) fall back to ``period``
            period: Fallback window length in seconds
        """
        now = datetime.utcnow()
        expires_at = now + timedelta(seconds=ttl if ttl > 0 else period)
        next_midnight = datetime(now.year, now.month, now.day) + timedelta(days=1)
        return min(expires_at, next_midnight)

    def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
        """Generate Redis key for quota tracking (a new key every UTC calendar day)."""
        date_str = datetime.utcnow().strftime("%Y-%m-%d")
        return f"quota:{period}:{self._quota_name(quota_type)}:{tenant_id}:{date_str}"

    def _get_dataset_size_key(self, tenant_id: str) -> str:
        """Generate Redis key for dataset size tracking"""
        return f"dataset_size:{tenant_id}"

    async def check_and_increment_quota(
        self,
        tenant_id: str,
        quota_type: str,
        limit: Optional[int],
        period: int = 86400  # 24 hours in seconds
    ) -> Dict[str, Any]:
        """
        Check if quota allows action and increment counter

        The counter is incremented atomically *before* the limit is checked,
        so concurrent requests each observe a distinct count and cannot
        jointly overshoot the limit (a read-then-increment sequence could).
        A rejected request rolls its own increment back. The key's TTL is
        (re)set to ``period`` on every attempt, in the same transaction as the
        increment.

        If Redis is unavailable while counting, the operation is allowed
        (fail open); once a request has been judged over the limit it is
        always rejected.

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check (str or QuotaType member)
            limit: Maximum allowed count (None = unlimited)
            period: TTL in seconds applied to the counter key (default: 24 hours)

        Returns:
            Dict with:
            - allowed: bool
            - current: int (current count)
            - limit: Optional[int]
            - reset_at: datetime (naive UTC; when the quota resets) or None

        Raises:
            HTTPException: 429 if quota is exceeded
        """
        if limit is None:
            # Unlimited quota
            return {
                "allowed": True,
                "current": 0,
                "limit": None,
                "reset_at": None
            }

        quota_name = self._quota_name(quota_type)
        key = self._get_quota_key(tenant_id, quota_name)

        try:
            # One MULTI/EXEC round trip: INCR returns the post-increment value,
            # EXPIRE keeps the key from living forever, TTL feeds reset_at.
            pipe = self.redis.pipeline()
            pipe.incr(key)
            pipe.expire(key, period)
            pipe.ttl(key)
            new_count_raw, _expire_ok, ttl = await pipe.execute()
            new_count = int(new_count_raw)
        except Exception as e:
            self.logger.error(
                "quota_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_name
            )
            # Fail open - allow the operation
            return {
                "allowed": True,
                "current": 0,
                "limit": limit,
                "reset_at": None
            }

        reset_at = self._compute_reset_at(ttl, period)

        if new_count > limit:
            # Count as it was before this (rejected) attempt.
            current_count = new_count - 1
            try:
                # Undo our speculative increment so rejected calls don't
                # inflate the stored usage.
                await self.redis.decr(key)
            except Exception as e:
                # Best-effort rollback; the request is rejected regardless.
                self.logger.error(
                    "quota_rollback_failed",
                    error=str(e),
                    tenant_id=tenant_id,
                    quota_type=quota_name
                )

            self.logger.warning(
                "quota_exceeded",
                tenant_id=tenant_id,
                quota_type=quota_name,
                current=current_count,
                limit=limit,
                reset_at=reset_at.isoformat()
            )

            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "error": "quota_exceeded",
                    "message": f"Daily quota exceeded for {quota_name}",
                    "current": current_count,
                    "limit": limit,
                    "reset_at": reset_at.isoformat(),
                    "quota_type": quota_name
                }
            )

        self.logger.info(
            "quota_incremented",
            tenant_id=tenant_id,
            quota_type=quota_name,
            current=new_count,
            limit=limit
        )

        return {
            "allowed": True,
            "current": new_count,
            "limit": limit,
            "reset_at": reset_at
        }

    async def get_current_usage(
        self,
        tenant_id: str,
        quota_type: str
    ) -> Dict[str, Any]:
        """
        Get current quota usage without incrementing

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check (str or QuotaType member)

        Returns:
            Dict with ``current`` (int) and ``reset_at`` (naive UTC datetime,
            or None when the key has no TTL or Redis is unavailable)
        """
        key = self._get_quota_key(tenant_id, quota_type)

        try:
            current = await self.redis.get(key)
            current_count = int(current) if current else 0
            ttl = await self.redis.ttl(key)
            reset_at = self._compute_reset_at(ttl) if ttl > 0 else None

            return {
                "current": current_count,
                "reset_at": reset_at
            }

        except Exception as e:
            self.logger.error(
                "usage_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=self._quota_name(quota_type)
            )
            return {
                "current": 0,
                "reset_at": None
            }

    async def reset_quota(self, tenant_id: str, quota_type: str):
        """
        Reset today's quota counter for a tenant (admin function)

        Best-effort: Redis errors are logged, not raised.

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to reset (str or QuotaType member)
        """
        key = self._get_quota_key(tenant_id, quota_type)
        quota_name = self._quota_name(quota_type)
        try:
            await self.redis.delete(key)
            self.logger.info(
                "quota_reset",
                tenant_id=tenant_id,
                quota_type=quota_name
            )
        except Exception as e:
            self.logger.error(
                "quota_reset_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_name
            )

    async def validate_dataset_size(
        self,
        tenant_id: str,
        dataset_rows: int,
        subscription_tier: str
    ):
        """
        Validate dataset size against subscription tier limits

        Unknown tiers are held to the starter limit (previously they were
        treated as unlimited, unlike the other tier validators).

        Args:
            tenant_id: Tenant ID
            dataset_rows: Number of rows in dataset
            subscription_tier: User's subscription tier (case-insensitive)

        Raises:
            HTTPException: 402 if dataset size exceeds limit
        """
        limit = self._tier_limit(self._DATASET_LIMITS, subscription_tier)

        if limit is not None and dataset_rows > limit:
            self.logger.warning(
                "dataset_size_exceeded",
                tenant_id=tenant_id,
                dataset_rows=dataset_rows,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "dataset_size_limit_exceeded",
                    "message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
                    "current_size": dataset_rows,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "dataset_size_validated",
            tenant_id=tenant_id,
            dataset_rows=dataset_rows,
            tier=subscription_tier
        )

    async def validate_forecast_horizon(
        self,
        tenant_id: str,
        horizon_days: int,
        subscription_tier: str
    ):
        """
        Validate forecast horizon against subscription tier limits

        Args:
            tenant_id: Tenant ID
            horizon_days: Number of days to forecast
            subscription_tier: User's subscription tier (case-insensitive;
                unknown tiers get the starter limit)

        Raises:
            HTTPException: 402 if horizon exceeds limit
        """
        limit = self._tier_limit(self._HORIZON_LIMITS, subscription_tier)

        if limit is not None and horizon_days > limit:
            self.logger.warning(
                "forecast_horizon_exceeded",
                tenant_id=tenant_id,
                horizon_days=horizon_days,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "forecast_horizon_limit_exceeded",
                    "message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
                    "requested_horizon": horizon_days,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "forecast_horizon_validated",
            tenant_id=tenant_id,
            horizon_days=horizon_days,
            tier=subscription_tier
        )

    async def validate_historical_data_access(
        self,
        tenant_id: str,
        days_back: int,
        subscription_tier: str
    ):
        """
        Validate historical data access against subscription tier limits

        Args:
            tenant_id: Tenant ID
            days_back: Number of days of historical data requested
            subscription_tier: User's subscription tier (case-insensitive;
                unknown tiers get the starter limit)

        Raises:
            HTTPException: 402 if historical data access exceeds limit
        """
        limit = self._tier_limit(self._HISTORY_LIMITS, subscription_tier)

        if limit is not None and days_back > limit:
            self.logger.warning(
                "historical_data_limit_exceeded",
                tenant_id=tenant_id,
                days_back=days_back,
                limit=limit,
                tier=subscription_tier
            )

            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "historical_data_limit_exceeded",
                    "message": f"Historical data limited to {limit} days for {subscription_tier} tier",
                    "requested_days": days_back,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )

        self.logger.info(
            "historical_data_access_validated",
            tenant_id=tenant_id,
            days_back=days_back,
            tier=subscription_tier
        )
|
|
|
|
|
|
def create_rate_limiter(redis_client) -> RateLimiter:
    """
    Build a RateLimiter bound to the given Redis client.

    Args:
        redis_client: Redis client handed straight to ``RateLimiter.__init__``

    Returns:
        A new RateLimiter instance
    """
    limiter = RateLimiter(redis_client=redis_client)
    return limiter
|