""" Rate limiting and quota management system for subscription-based features """ import time from typing import Optional, Dict, Any from datetime import datetime, timedelta from enum import Enum import structlog from fastapi import HTTPException, status logger = structlog.get_logger() class QuotaType(str, Enum): """Types of quotas""" FORECAST_GENERATION = "forecast_generation" TRAINING_JOBS = "training_jobs" BULK_IMPORTS = "bulk_imports" POS_SYNC = "pos_sync" API_CALLS = "api_calls" DEMO_SESSIONS = "demo_sessions" class RateLimiter: """ Redis-based rate limiter for subscription tier quotas """ def __init__(self, redis_client): """ Initialize rate limiter Args: redis_client: Redis client for storing quota counters """ self.redis = redis_client self.logger = logger def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str: """Generate Redis key for quota tracking""" date_str = datetime.utcnow().strftime("%Y-%m-%d") return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}" def _get_dataset_size_key(self, tenant_id: str) -> str: """Generate Redis key for dataset size tracking""" return f"dataset_size:{tenant_id}" async def check_and_increment_quota( self, tenant_id: str, quota_type: str, limit: Optional[int], period: int = 86400 # 24 hours in seconds ) -> Dict[str, Any]: """ Check if quota allows action and increment counter Args: tenant_id: Tenant ID quota_type: Type of quota to check limit: Maximum allowed count (None = unlimited) period: Time period in seconds (default: 24 hours) Returns: Dict with: - allowed: bool - current: int (current count) - limit: Optional[int] - reset_at: datetime (when quota resets) Raises: HTTPException: If quota is exceeded """ if limit is None: # Unlimited quota return { "allowed": True, "current": 0, "limit": None, "reset_at": None } key = self._get_quota_key(tenant_id, quota_type) try: # Get current count current = await self.redis.get(key) current_count = int(current) if current else 0 # Check if limit exceeded if current_count >= limit: ttl = await self.redis.ttl(key) reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period) self.logger.warning( "quota_exceeded", tenant_id=tenant_id, quota_type=quota_type, current=current_count, limit=limit, reset_at=reset_at.isoformat() ) raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail={ "error": "quota_exceeded", "message": f"Daily quota exceeded for {quota_type}", "current": current_count, "limit": limit, "reset_at": reset_at.isoformat(), "quota_type": quota_type } ) # Increment counter pipe = self.redis.pipeline() pipe.incr(key) pipe.expire(key, period) await pipe.execute() new_count = current_count + 1 ttl = await self.redis.ttl(key) reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period) self.logger.info( "quota_incremented", tenant_id=tenant_id, quota_type=quota_type, current=new_count, limit=limit ) return { "allowed": True, "current": new_count, "limit": limit, "reset_at": reset_at } except HTTPException: raise except Exception as e: self.logger.error( "quota_check_failed", error=str(e), tenant_id=tenant_id, quota_type=quota_type ) # Fail open - allow the operation return { "allowed": True, "current": 0, "limit": limit, "reset_at": None } async def get_current_usage( self, tenant_id: str, quota_type: str ) -> Dict[str, Any]: """ Get current quota usage without incrementing Args: tenant_id: Tenant ID quota_type: Type of quota to check Returns: Dict with current usage information """ key = self._get_quota_key(tenant_id, quota_type) try: current = await self.redis.get(key) current_count = int(current) if current else 0 ttl = await self.redis.ttl(key) reset_at = datetime.utcnow() + timedelta(seconds=ttl) if ttl > 0 else None return { "current": current_count, "reset_at": reset_at } except Exception as e: self.logger.error( "usage_check_failed", error=str(e), tenant_id=tenant_id, quota_type=quota_type ) return { "current": 0, "reset_at": None } async def reset_quota(self, tenant_id: str, quota_type: str): """ Reset quota for a tenant (admin function) Args: tenant_id: Tenant ID quota_type: Type of quota to reset """ key = self._get_quota_key(tenant_id, quota_type) try: await self.redis.delete(key) self.logger.info( "quota_reset", tenant_id=tenant_id, quota_type=quota_type ) except Exception as e: self.logger.error( "quota_reset_failed", error=str(e), tenant_id=tenant_id, quota_type=quota_type ) async def validate_dataset_size( self, tenant_id: str, dataset_rows: int, subscription_tier: str ): """ Validate dataset size against subscription tier limits Args: tenant_id: Tenant ID dataset_rows: Number of rows in dataset subscription_tier: User's subscription tier Raises: HTTPException: If dataset size exceeds limit """ # Dataset size limits per tier dataset_limits = { 'starter': 1000, 'professional': 10000, 'enterprise': None # Unlimited } limit = dataset_limits.get(subscription_tier.lower()) if limit is not None and dataset_rows > limit: self.logger.warning( "dataset_size_exceeded", tenant_id=tenant_id, dataset_rows=dataset_rows, limit=limit, tier=subscription_tier ) raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail={ "error": "dataset_size_limit_exceeded", "message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier", "current_size": dataset_rows, "limit": limit, "tier": subscription_tier, "upgrade_url": "/app/settings/profile" } ) self.logger.info( "dataset_size_validated", tenant_id=tenant_id, dataset_rows=dataset_rows, tier=subscription_tier ) async def validate_forecast_horizon( self, tenant_id: str, horizon_days: int, subscription_tier: str ): """ Validate forecast horizon against subscription tier limits Args: tenant_id: Tenant ID horizon_days: Number of days to forecast subscription_tier: User's subscription tier Raises: HTTPException: If horizon exceeds limit """ # Forecast horizon limits per tier horizon_limits = { 'starter': 7, 'professional': 90, 'enterprise': 365 # Practically unlimited } limit = horizon_limits.get(subscription_tier.lower(), 7) if horizon_days > limit: self.logger.warning( "forecast_horizon_exceeded", tenant_id=tenant_id, horizon_days=horizon_days, limit=limit, tier=subscription_tier ) raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail={ "error": "forecast_horizon_limit_exceeded", "message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier", "requested_horizon": horizon_days, "limit": limit, "tier": subscription_tier, "upgrade_url": "/app/settings/profile" } ) self.logger.info( "forecast_horizon_validated", tenant_id=tenant_id, horizon_days=horizon_days, tier=subscription_tier ) async def validate_historical_data_access( self, tenant_id: str, days_back: int, subscription_tier: str ): """ Validate historical data access against subscription tier limits Args: tenant_id: Tenant ID days_back: Number of days of historical data requested subscription_tier: User's subscription tier Raises: HTTPException: If historical data access exceeds limit """ # Historical data limits per tier history_limits = { 'starter': 7, 'professional': 90, 'enterprise': None # Unlimited } limit = history_limits.get(subscription_tier.lower(), 7) if limit is not None and days_back > limit: self.logger.warning( "historical_data_limit_exceeded", tenant_id=tenant_id, days_back=days_back, limit=limit, tier=subscription_tier ) raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail={ "error": "historical_data_limit_exceeded", "message": f"Historical data limited to {limit} days for {subscription_tier} tier", "requested_days": days_back, "limit": limit, "tier": subscription_tier, "upgrade_url": "/app/settings/profile" } ) self.logger.info( "historical_data_access_validated", tenant_id=tenant_id, days_back=days_back, tier=subscription_tier ) def create_rate_limiter(redis_client) -> RateLimiter: """Factory function to create rate limiter""" return RateLimiter(redis_client)