Add role-based filtering and improve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -0,0 +1,31 @@
"""
Security utilities for RBAC, audit logging, and rate limiting
"""
from shared.security.audit_logger import (
AuditLogger,
AuditSeverity,
AuditAction,
create_audit_logger,
create_audit_log_model
)
from shared.security.rate_limiter import (
RateLimiter,
QuotaType,
create_rate_limiter
)
__all__ = [
# Audit logging
"AuditLogger",
"AuditSeverity",
"AuditAction",
"create_audit_logger",
"create_audit_log_model",
# Rate limiting
"RateLimiter",
"QuotaType",
"create_rate_limiter",
]

View File

@@ -0,0 +1,317 @@
"""
Audit logging system for tracking critical operations across all services
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from enum import Enum
import structlog
from sqlalchemy import Column, String, DateTime, Text, Index
from sqlalchemy.dialects.postgresql import UUID, JSON
logger = structlog.get_logger()
class AuditSeverity(str, Enum):
    """Severity levels for audit events.

    Subclasses ``str`` so members serialize directly in JSON payloads and
    compare equal to their plain-string values (the AuditLog ``severity``
    column stores the raw string).
    """
    LOW = "low"
    MEDIUM = "medium"        # default severity used by AuditLogger.log_event
    HIGH = "high"            # used for deletions and role changes in this module
    CRITICAL = "critical"    # used for subscription changes in this module
class AuditAction(str, Enum):
    """Common audit action types.

    Subclasses ``str`` so ``.value`` round-trips cleanly through the
    string-typed ``action`` column. The set is open-ended: callers may pass
    other action strings to ``AuditLogger.log_event``; these are just the
    common, shared vocabulary.
    """
    # CRUD
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
    # Workflow transitions
    APPROVE = "approve"
    REJECT = "reject"
    CANCEL = "cancel"
    # Data movement
    EXPORT = "export"
    IMPORT = "import"
    # Membership management
    INVITE = "invite"
    REMOVE = "remove"
    # Subscription / account lifecycle
    UPGRADE = "upgrade"
    DOWNGRADE = "downgrade"
    DEACTIVATE = "deactivate"
    ACTIVATE = "activate"
def create_audit_log_model(Base):
    """
    Factory that builds an ``AuditLog`` model bound to a service's declarative Base.

    Each service keeps its own ``audit_logs`` table in its own database
    (data locality), so the model cannot be declared once at import time —
    it must be created per-service against that service's ``Base``.

    Usage in a service's ``models/__init__.py``::

        from shared.database.base import Base
        from shared.security import create_audit_log_model
        AuditLog = create_audit_log_model(Base)

    Args:
        Base: SQLAlchemy declarative base for the service.

    Returns:
        AuditLog model class bound to the service's Base.
    """
    class AuditLog(Base):
        """
        Audit log row recording one critical operation.

        Each service has its own ``audit_logs`` table for data locality.
        """
        __tablename__ = "audit_logs"

        # Primary identification
        id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

        # Tenant and user context — both required: every audit row must be
        # attributable to a tenant and the acting user.
        tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
        user_id = Column(UUID(as_uuid=True), nullable=False, index=True)

        # Action details
        action = Column(String(100), nullable=False, index=True)  # create, update, delete, etc.
        resource_type = Column(String(100), nullable=False, index=True)  # supplier, recipe, order, etc.
        # String rather than UUID: may hold ids of any shape (see
        # log_subscription_change, which stores the tenant id here).
        resource_id = Column(String(255), nullable=True, index=True)

        # Severity: low, medium, high, critical (AuditSeverity values are the
        # expected inputs, stored as plain strings).
        severity = Column(
            String(20),
            nullable=False,
            default="medium",
            index=True
        )

        # Which microservice emitted the event
        service_name = Column(String(100), nullable=False, index=True)

        # Human-readable description of the event
        description = Column(Text, nullable=True)

        # Audit trail data
        changes = Column(JSON, nullable=True)  # Before/after values for updates
        # Named audit_metadata (presumably to avoid SQLAlchemy's reserved
        # ``metadata`` attribute — note it is exposed as "metadata" in to_dict()).
        audit_metadata = Column(JSON, nullable=True)  # Additional context

        # Request context
        ip_address = Column(String(45), nullable=True)  # 45 chars fits IPv4 and full IPv6
        user_agent = Column(Text, nullable=True)
        endpoint = Column(String(255), nullable=True)
        method = Column(String(10), nullable=True)  # GET, POST, PUT, DELETE

        # Timestamps — timezone-aware UTC, set at insert time
        created_at = Column(
            DateTime(timezone=True),
            nullable=False,
            default=lambda: datetime.now(timezone.utc),
            index=True
        )

        # Composite indexes for common query patterns
        __table_args__ = (
            Index('idx_audit_tenant_created', 'tenant_id', 'created_at'),
            Index('idx_audit_user_created', 'user_id', 'created_at'),
            Index('idx_audit_resource_type_action', 'resource_type', 'action'),
            Index('idx_audit_severity_created', 'severity', 'created_at'),
            Index('idx_audit_service_created', 'service_name', 'created_at'),
        )

        def __repr__(self):
            return (
                f"<AuditLog(id={self.id}, tenant={self.tenant_id}, "
                f"action={self.action}, resource={self.resource_type}, "
                f"severity={self.severity})>"
            )

        def to_dict(self):
            """Convert the audit log row to a JSON-serializable dictionary."""
            return {
                "id": str(self.id),
                "tenant_id": str(self.tenant_id),
                "user_id": str(self.user_id),
                "action": self.action,
                "resource_type": self.resource_type,
                "resource_id": self.resource_id,
                "severity": self.severity,
                "service_name": self.service_name,
                "description": self.description,
                "changes": self.changes,
                # external key is "metadata" even though the column is audit_metadata
                "metadata": self.audit_metadata,
                "ip_address": self.ip_address,
                "user_agent": self.user_agent,
                "endpoint": self.endpoint,
                "method": self.method,
                "created_at": self.created_at.isoformat() if self.created_at else None,
            }

    return AuditLog
class AuditLogger:
    """Service for logging audit events.

    Bug fix: the original ``log_event`` instantiated a module-global
    ``AuditLog`` name that is never defined in this module (the model is
    produced per-service by ``create_audit_log_model``), so every call raised
    ``NameError``, which the catch-all handler swallowed — no audit row was
    ever persisted. The model class must now be supplied via
    ``audit_log_model``; when it is missing, the misconfiguration is logged
    loudly instead of failing silently.
    """

    def __init__(self, service_name: str, audit_log_model=None):
        """
        Args:
            service_name: Name of the service emitting audit events.
            audit_log_model: AuditLog model class created by
                ``create_audit_log_model`` for this service's Base. Optional
                for backward compatibility with the one-argument constructor;
                without it, events cannot be persisted and are reported as
                errors in the application log.
        """
        self.service_name = service_name
        self.audit_log_model = audit_log_model
        self.logger = logger.bind(service=service_name)

    async def log_event(
        self,
        db_session,
        tenant_id: str,
        user_id: str,
        action: str,
        resource_type: str,
        resource_id: Optional[str] = None,
        severity: str = "medium",
        description: Optional[str] = None,
        changes: Optional[Dict[str, Any]] = None,
        audit_metadata: Optional[Dict[str, Any]] = None,
        endpoint: Optional[str] = None,
        method: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ):
        """
        Log an audit event.

        Never raises: audit logging must not block the business operation,
        so any failure is logged and swallowed.

        Args:
            db_session: Async database session
            tenant_id: Tenant ID (str or UUID)
            user_id: User ID who performed the action (str or UUID)
            action: Action performed (create, update, delete, etc.)
            resource_type: Type of resource (user, sale, recipe, etc.)
            resource_id: ID of the resource affected
            severity: Severity level (low, medium, high, critical)
            description: Human-readable description
            changes: Dictionary of before/after values for updates
            audit_metadata: Additional context
            endpoint: API endpoint
            method: HTTP method
            ip_address: Client IP address
            user_agent: Client user agent
        """
        try:
            if self.audit_log_model is None:
                # Misconfiguration: surface it in the app log but keep the
                # "never block the caller" contract.
                self.logger.error(
                    "audit_log_model_not_configured",
                    tenant_id=str(tenant_id),
                    user_id=str(user_id),
                    action=action,
                )
                return

            audit_log = self.audit_log_model(
                # Accept both str and UUID inputs for the id columns.
                tenant_id=uuid.UUID(tenant_id) if isinstance(tenant_id, str) else tenant_id,
                user_id=uuid.UUID(user_id) if isinstance(user_id, str) else user_id,
                action=action,
                resource_type=resource_type,
                resource_id=resource_id,
                severity=severity,
                service_name=self.service_name,
                description=description,
                changes=changes,
                audit_metadata=audit_metadata,
                endpoint=endpoint,
                method=method,
                ip_address=ip_address,
                user_agent=user_agent,
            )
            db_session.add(audit_log)
            await db_session.commit()

            self.logger.info(
                "audit_event_logged",
                tenant_id=str(tenant_id),
                user_id=str(user_id),
                action=action,
                resource_type=resource_type,
                resource_id=resource_id,
                severity=severity,
            )
        except Exception as e:
            self.logger.error(
                "audit_log_failed",
                error=str(e),
                tenant_id=str(tenant_id),
                user_id=str(user_id),
                action=action,
            )
            # Don't raise - audit logging should not block operations

    async def log_deletion(
        self,
        db_session,
        tenant_id: str,
        user_id: str,
        resource_type: str,
        resource_id: str,
        resource_data: Optional[Dict[str, Any]] = None,
        **kwargs
    ):
        """Convenience method for logging deletions.

        Records the deleted payload (if given) under metadata key
        ``deleted_data`` at HIGH severity.
        """
        return await self.log_event(
            db_session=db_session,
            tenant_id=tenant_id,
            user_id=user_id,
            action=AuditAction.DELETE.value,
            resource_type=resource_type,
            resource_id=resource_id,
            severity=AuditSeverity.HIGH.value,
            description=f"Deleted {resource_type} {resource_id}",
            audit_metadata={"deleted_data": resource_data} if resource_data else None,
            **kwargs
        )

    async def log_role_change(
        self,
        db_session,
        tenant_id: str,
        user_id: str,
        target_user_id: str,
        old_role: str,
        new_role: str,
        **kwargs
    ):
        """Convenience method for logging role changes.

        ``user_id`` is the actor; ``target_user_id`` (the user whose role
        changed) is stored as the resource id. Logged at HIGH severity with a
        before/after changes payload.
        """
        return await self.log_event(
            db_session=db_session,
            tenant_id=tenant_id,
            user_id=user_id,
            action=AuditAction.UPDATE.value,
            resource_type="user_role",
            resource_id=target_user_id,
            severity=AuditSeverity.HIGH.value,
            description=f"Changed user role from {old_role} to {new_role}",
            changes={
                "before": {"role": old_role},
                "after": {"role": new_role}
            },
            **kwargs
        )

    async def log_subscription_change(
        self,
        db_session,
        tenant_id: str,
        user_id: str,
        action: str,
        old_plan: Optional[str] = None,
        new_plan: Optional[str] = None,
        **kwargs
    ):
        """Convenience method for logging subscription changes.

        Logged at CRITICAL severity; the tenant id doubles as the resource id
        because a subscription belongs to the tenant itself.
        """
        return await self.log_event(
            db_session=db_session,
            tenant_id=tenant_id,
            user_id=user_id,
            action=action,
            resource_type="subscription",
            resource_id=tenant_id,
            severity=AuditSeverity.CRITICAL.value,
            description=f"Subscription {action}: {old_plan} -> {new_plan}" if old_plan else f"Subscription {action}: {new_plan}",
            changes={
                "before": {"plan": old_plan} if old_plan else None,
                "after": {"plan": new_plan} if new_plan else None
            },
            **kwargs
        )
def create_audit_logger(service_name: str, audit_log_model=None) -> AuditLogger:
    """Factory function to create an audit logger for a service.

    Args:
        service_name: Name of the service emitting audit events.
        audit_log_model: Optional AuditLog model class (from
            ``create_audit_log_model``); forwarded to AuditLogger only when
            provided, so existing one-argument call sites are unaffected.

    Returns:
        Configured AuditLogger instance.
    """
    if audit_log_model is None:
        return AuditLogger(service_name)
    return AuditLogger(service_name, audit_log_model)

View File

@@ -0,0 +1,388 @@
"""
Rate limiting and quota management system for subscription-based features
"""
import time
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from enum import Enum
import structlog
from fastapi import HTTPException, status
logger = structlog.get_logger()
class QuotaType(str, Enum):
    """Types of per-tenant quotas.

    Subclasses ``str`` so member values embed directly into Redis key names
    and JSON responses.
    """
    FORECAST_GENERATION = "forecast_generation"
    TRAINING_JOBS = "training_jobs"
    BULK_IMPORTS = "bulk_imports"
    POS_SYNC = "pos_sync"
    API_CALLS = "api_calls"
    DEMO_SESSIONS = "demo_sessions"
class RateLimiter:
    """
    Redis-based rate limiter for subscription tier quotas.

    Fixes over the original implementation:

    * ``check_and_increment_quota`` used a non-atomic GET → check → INCR
      sequence, so concurrent requests arriving at the quota boundary could
      all read the same count and all be admitted. It now relies on a single
      atomic ``INCR`` and checks the returned value.
    * The TTL is attached only on the first increment of a window; the
      original re-ran ``EXPIRE`` on every increment, pushing the window
      forward under sustained traffic. (The key also embeds the UTC date, so
      the TTL is primarily garbage collection.)
    * ``validate_dataset_size`` previously treated an *unknown* tier as
      unlimited; it now falls back to the strictest (starter) limit, matching
      the other validators.
    """

    def __init__(self, redis_client):
        """
        Initialize rate limiter.

        Args:
            redis_client: Async Redis client for storing quota counters.
        """
        self.redis = redis_client
        self.logger = logger

    def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
        """Generate the Redis key for quota tracking (one key per UTC day)."""
        # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
        # consider datetime.now(timezone.utc) — behavior here is unchanged.
        date_str = datetime.utcnow().strftime("%Y-%m-%d")
        return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}"

    def _get_dataset_size_key(self, tenant_id: str) -> str:
        """Generate the Redis key for dataset size tracking."""
        return f"dataset_size:{tenant_id}"

    async def check_and_increment_quota(
        self,
        tenant_id: str,
        quota_type: str,
        limit: Optional[int],
        period: int = 86400  # 24 hours in seconds
    ) -> Dict[str, Any]:
        """
        Atomically consume one unit of quota, rejecting if over the limit.

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check
            limit: Maximum allowed count (None = unlimited)
            period: Time period in seconds (default: 24 hours)

        Returns:
            Dict with:
                - allowed: bool
                - current: int (count after this increment)
                - limit: Optional[int]
                - reset_at: datetime (when the quota window expires)

        Raises:
            HTTPException: 429 if the quota is exceeded.
        """
        if limit is None:
            # Unlimited quota — nothing to track.
            return {
                "allowed": True,
                "current": 0,
                "limit": None,
                "reset_at": None
            }

        key = self._get_quota_key(tenant_id, quota_type)
        try:
            # Atomic INCR first: no window between reading and bumping the
            # counter, so concurrent requests cannot all slip past the limit.
            new_count = await self.redis.incr(key)
            if new_count == 1:
                # First hit in this window: attach the TTL exactly once so the
                # key is garbage-collected after the period.
                await self.redis.expire(key, period)

            ttl = await self.redis.ttl(key)
            reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)

            if new_count > limit:
                self.logger.warning(
                    "quota_exceeded",
                    tenant_id=tenant_id,
                    quota_type=quota_type,
                    current=new_count,
                    limit=limit,
                    reset_at=reset_at.isoformat()
                )
                raise HTTPException(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    detail={
                        "error": "quota_exceeded",
                        "message": f"Daily quota exceeded for {quota_type}",
                        "current": new_count,
                        "limit": limit,
                        "reset_at": reset_at.isoformat(),
                        "quota_type": quota_type
                    }
                )

            self.logger.info(
                "quota_incremented",
                tenant_id=tenant_id,
                quota_type=quota_type,
                current=new_count,
                limit=limit
            )
            return {
                "allowed": True,
                "current": new_count,
                "limit": limit,
                "reset_at": reset_at
            }
        except HTTPException:
            raise
        except Exception as e:
            self.logger.error(
                "quota_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )
            # Fail open - Redis trouble must not take down the feature.
            return {
                "allowed": True,
                "current": 0,
                "limit": limit,
                "reset_at": None
            }

    async def get_current_usage(
        self,
        tenant_id: str,
        quota_type: str
    ) -> Dict[str, Any]:
        """
        Get current quota usage without incrementing.

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to check

        Returns:
            Dict with current usage information (fails soft to zero usage).
        """
        key = self._get_quota_key(tenant_id, quota_type)
        try:
            current = await self.redis.get(key)
            current_count = int(current) if current else 0
            ttl = await self.redis.ttl(key)
            reset_at = datetime.utcnow() + timedelta(seconds=ttl) if ttl > 0 else None
            return {
                "current": current_count,
                "reset_at": reset_at
            }
        except Exception as e:
            self.logger.error(
                "usage_check_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )
            return {
                "current": 0,
                "reset_at": None
            }

    async def reset_quota(self, tenant_id: str, quota_type: str):
        """
        Reset quota for a tenant (admin function).

        Args:
            tenant_id: Tenant ID
            quota_type: Type of quota to reset
        """
        key = self._get_quota_key(tenant_id, quota_type)
        try:
            await self.redis.delete(key)
            self.logger.info(
                "quota_reset",
                tenant_id=tenant_id,
                quota_type=quota_type
            )
        except Exception as e:
            self.logger.error(
                "quota_reset_failed",
                error=str(e),
                tenant_id=tenant_id,
                quota_type=quota_type
            )

    async def validate_dataset_size(
        self,
        tenant_id: str,
        dataset_rows: int,
        subscription_tier: str
    ):
        """
        Validate dataset size against subscription tier limits.

        Args:
            tenant_id: Tenant ID
            dataset_rows: Number of rows in dataset
            subscription_tier: User's subscription tier

        Raises:
            HTTPException: 402 if dataset size exceeds the tier limit.
        """
        # Dataset size limits per tier
        dataset_limits = {
            'starter': 1000,
            'professional': 10000,
            'enterprise': None  # Unlimited
        }
        # Unknown tiers fall back to the strictest limit (the original
        # returned None here, silently granting unlimited data).
        limit = dataset_limits.get(subscription_tier.lower(), dataset_limits['starter'])
        if limit is not None and dataset_rows > limit:
            self.logger.warning(
                "dataset_size_exceeded",
                tenant_id=tenant_id,
                dataset_rows=dataset_rows,
                limit=limit,
                tier=subscription_tier
            )
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "dataset_size_limit_exceeded",
                    "message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
                    "current_size": dataset_rows,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )
        self.logger.info(
            "dataset_size_validated",
            tenant_id=tenant_id,
            dataset_rows=dataset_rows,
            tier=subscription_tier
        )

    async def validate_forecast_horizon(
        self,
        tenant_id: str,
        horizon_days: int,
        subscription_tier: str
    ):
        """
        Validate forecast horizon against subscription tier limits.

        Args:
            tenant_id: Tenant ID
            horizon_days: Number of days to forecast
            subscription_tier: User's subscription tier

        Raises:
            HTTPException: 402 if the horizon exceeds the tier limit.
        """
        # Forecast horizon limits per tier; unknown tiers get the strictest.
        horizon_limits = {
            'starter': 7,
            'professional': 90,
            'enterprise': 365  # Practically unlimited
        }
        limit = horizon_limits.get(subscription_tier.lower(), 7)
        if horizon_days > limit:
            self.logger.warning(
                "forecast_horizon_exceeded",
                tenant_id=tenant_id,
                horizon_days=horizon_days,
                limit=limit,
                tier=subscription_tier
            )
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "forecast_horizon_limit_exceeded",
                    "message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
                    "requested_horizon": horizon_days,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )
        self.logger.info(
            "forecast_horizon_validated",
            tenant_id=tenant_id,
            horizon_days=horizon_days,
            tier=subscription_tier
        )

    async def validate_historical_data_access(
        self,
        tenant_id: str,
        days_back: int,
        subscription_tier: str
    ):
        """
        Validate historical data access against subscription tier limits.

        Args:
            tenant_id: Tenant ID
            days_back: Number of days of historical data requested
            subscription_tier: User's subscription tier

        Raises:
            HTTPException: 402 if historical data access exceeds the limit.
        """
        # Historical data limits per tier; unknown tiers get the strictest.
        history_limits = {
            'starter': 7,
            'professional': 90,
            'enterprise': None  # Unlimited
        }
        limit = history_limits.get(subscription_tier.lower(), 7)
        if limit is not None and days_back > limit:
            self.logger.warning(
                "historical_data_limit_exceeded",
                tenant_id=tenant_id,
                days_back=days_back,
                limit=limit,
                tier=subscription_tier
            )
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "historical_data_limit_exceeded",
                    "message": f"Historical data limited to {limit} days for {subscription_tier} tier",
                    "requested_days": days_back,
                    "limit": limit,
                    "tier": subscription_tier,
                    "upgrade_url": "/app/settings/profile"
                }
            )
        self.logger.info(
            "historical_data_access_validated",
            tenant_id=tenant_id,
            days_back=days_back,
            tier=subscription_tier
        )
def create_rate_limiter(redis_client) -> RateLimiter:
    """Build a :class:`RateLimiter` bound to the given Redis client."""
    limiter = RateLimiter(redis_client)
    return limiter