Add role-based filtering and improve code

Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -14,6 +14,7 @@ from urllib.parse import urljoin
from shared.auth.jwt_handler import JWTHandler
from shared.config.base import BaseServiceSettings
from shared.clients.circuit_breaker import CircuitBreaker, CircuitBreakerOpenException
logger = structlog.get_logger()
@@ -91,11 +92,19 @@ class BaseServiceClient(ABC):
self.config = config
self.gateway_url = config.GATEWAY_URL
self.authenticator = ServiceAuthenticator(service_name, config)
# HTTP client configuration
self.timeout = config.HTTP_TIMEOUT
self.retries = config.HTTP_RETRIES
self.retry_delay = config.HTTP_RETRY_DELAY
# Circuit breaker for fault tolerance
self.circuit_breaker = CircuitBreaker(
service_name=f"{service_name}-client",
failure_threshold=5,
timeout=60,
success_threshold=2
)
@abstractmethod
def get_service_base_path(self) -> str:
@@ -113,8 +122,8 @@ class BaseServiceClient(ABC):
timeout: Optional[Union[int, httpx.Timeout]] = None
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""
Make an authenticated request to another service via gateway with circuit breaker protection.
Args:
method: HTTP method (GET, POST, PUT, DELETE)
endpoint: API endpoint (will be prefixed with service base path)
@@ -123,10 +132,53 @@ class BaseServiceClient(ABC):
params: Query parameters
headers: Additional headers
timeout: Request timeout override
Returns:
Response data or None if request failed
"""
try:
# Wrap request in circuit breaker
return await self.circuit_breaker.call(
self._do_request,
method,
endpoint,
tenant_id,
data,
params,
headers,
timeout
)
except CircuitBreakerOpenException as e:
logger.error(
"Circuit breaker open - request rejected",
service=self.service_name,
endpoint=endpoint,
error=str(e)
)
return None
except Exception as e:
logger.error(
"Unexpected error in request",
service=self.service_name,
endpoint=endpoint,
error=str(e)
)
return None
async def _do_request(
self,
method: str,
endpoint: str,
tenant_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[Union[int, httpx.Timeout]] = None
) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""
Internal method to execute HTTP request with retries.
Called by _make_request through circuit breaker.
"""
try:
# Get service token
token = await self.authenticator.get_service_token()
@@ -135,7 +187,11 @@ class BaseServiceClient(ABC):
request_headers = self.authenticator.get_request_headers(tenant_id)
request_headers["Authorization"] = f"Bearer {token}"
request_headers["Content-Type"] = "application/json"
# Merge any caller-supplied headers; this also propagates X-Request-ID for distributed tracing
if headers:
request_headers.update(headers)
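
For illustration, a minimal sketch of how a concrete client inherits the new protection; the InventoryClient name and /items endpoint are hypothetical:

# Hypothetical subclass: the circuit breaker wrapping happens in the base class.
class InventoryClient(BaseServiceClient):
    def get_service_base_path(self) -> str:
        return "/api/v1/inventory"

async def fetch_stock(client: InventoryClient, tenant_id: str) -> list:
    # _make_request returns None both on HTTP failure and when the circuit
    # is open, so callers need only a single fallback path.
    items = await client._make_request("GET", "/items", tenant_id=tenant_id)
    return items or []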

View File

@@ -0,0 +1,215 @@
"""
Circuit Breaker implementation for inter-service communication
Prevents cascading failures by failing fast when a service is unhealthy
"""
import time
import structlog
from enum import Enum
from typing import Callable, Any, Optional
import asyncio
logger = structlog.get_logger()
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation, requests pass through
OPEN = "open" # Service is failing, reject requests immediately
HALF_OPEN = "half_open" # Testing if service has recovered
class CircuitBreakerOpenException(Exception):
"""Raised when circuit breaker is open and rejects a request"""
pass
class CircuitBreaker:
"""
Circuit breaker pattern implementation for preventing cascading failures.
States:
- CLOSED: Normal operation, all requests pass through
- OPEN: Service is failing, reject all requests immediately
- HALF_OPEN: Testing recovery, allow trial requests through
Transitions:
- CLOSED -> OPEN: After failure_threshold consecutive failures
- OPEN -> HALF_OPEN: After timeout seconds have passed
- HALF_OPEN -> CLOSED: After success_threshold consecutive successes
- HALF_OPEN -> OPEN: If a trial request fails
"""
def __init__(
self,
service_name: str,
failure_threshold: int = 5,
timeout: int = 60,
success_threshold: int = 2
):
"""
Initialize circuit breaker.
Args:
service_name: Name of the service being protected
failure_threshold: Number of consecutive failures before opening circuit
timeout: Seconds to wait before attempting recovery (half-open state)
success_threshold: Consecutive successes needed to close from half-open
"""
self.service_name = service_name
self.failure_threshold = failure_threshold
self.timeout = timeout
self.success_threshold = success_threshold
self.state = CircuitState.CLOSED
self.failure_count = 0
self.success_count = 0
self.last_failure_time: Optional[float] = None
self._lock = asyncio.Lock()
logger.info(
"Circuit breaker initialized",
service=service_name,
failure_threshold=failure_threshold,
timeout=timeout
)
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Async function to execute
*args, **kwargs: Arguments to pass to func
Returns:
Result from func
Raises:
CircuitBreakerOpenException: If circuit is open
Exception: Any exception raised by func
"""
async with self._lock:
# Check if circuit should transition to half-open
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
logger.info(
"Circuit breaker transitioning to half-open",
service=self.service_name
)
self.state = CircuitState.HALF_OPEN
self.success_count = 0
else:
# Circuit is open, reject request
raise CircuitBreakerOpenException(
f"Circuit breaker is OPEN for {self.service_name}. "
f"Service will be retried in {self._time_until_retry():.0f} seconds."
)
# Execute function
try:
result = await func(*args, **kwargs)
await self._on_success()
return result
except Exception as e:
await self._on_failure(e)
raise
def _should_attempt_reset(self) -> bool:
"""Check if enough time has passed to attempt recovery"""
if self.last_failure_time is None:
return True
return time.time() - self.last_failure_time >= self.timeout
def _time_until_retry(self) -> float:
"""Calculate seconds until next retry attempt"""
if self.last_failure_time is None:
return 0.0
elapsed = time.time() - self.last_failure_time
return max(0.0, self.timeout - elapsed)
async def _on_success(self):
"""Handle successful request"""
async with self._lock:
self.failure_count = 0
if self.state == CircuitState.HALF_OPEN:
self.success_count += 1
logger.debug(
"Circuit breaker success in half-open state",
service=self.service_name,
success_count=self.success_count,
success_threshold=self.success_threshold
)
if self.success_count >= self.success_threshold:
logger.info(
"Circuit breaker closing - service recovered",
service=self.service_name
)
self.state = CircuitState.CLOSED
self.success_count = 0
async def _on_failure(self, exception: Exception):
"""Handle failed request"""
async with self._lock:
self.failure_count += 1
self.last_failure_time = time.time()
if self.state == CircuitState.HALF_OPEN:
logger.warning(
"Circuit breaker opening - recovery attempt failed",
service=self.service_name,
error=str(exception)
)
self.state = CircuitState.OPEN
self.success_count = 0
elif self.state == CircuitState.CLOSED:
logger.warning(
"Circuit breaker failure recorded",
service=self.service_name,
failure_count=self.failure_count,
threshold=self.failure_threshold,
error=str(exception)
)
if self.failure_count >= self.failure_threshold:
logger.error(
"Circuit breaker opening - failure threshold reached",
service=self.service_name,
failure_count=self.failure_count
)
self.state = CircuitState.OPEN
def get_state(self) -> str:
"""Get current circuit breaker state"""
return self.state.value
def is_closed(self) -> bool:
"""Check if circuit is closed (normal operation)"""
return self.state == CircuitState.CLOSED
def is_open(self) -> bool:
"""Check if circuit is open (failing fast)"""
return self.state == CircuitState.OPEN
def is_half_open(self) -> bool:
"""Check if circuit is half-open (testing recovery)"""
return self.state == CircuitState.HALF_OPEN
async def reset(self):
"""Manually reset circuit breaker to closed state"""
async with self._lock:
logger.info(
"Circuit breaker manually reset",
service=self.service_name,
previous_state=self.state.value
)
self.state = CircuitState.CLOSED
self.failure_count = 0
self.success_count = 0
self.last_failure_time = None
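
A minimal usage sketch of the breaker above, assuming a hypothetical payments-service URL; an httpx error counts toward failure_threshold, while an open circuit is rejected without any network I/O:

import asyncio
import httpx

breaker = CircuitBreaker("payments-client", failure_threshold=5, timeout=60)

async def call_payments():
    async with httpx.AsyncClient() as client:
        resp = await client.get("http://payments-service:8000/health")  # hypothetical
        resp.raise_for_status()
        return resp.json()

async def main():
    try:
        result = await breaker.call(call_payments)
    except CircuitBreakerOpenException:
        result = None  # failing fast: rejected before any request was sent
    except httpx.HTTPError:
        result = None  # real failure: recorded by _on_failure
    print(breaker.get_state(), result)

asyncio.run(main())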

View File

@@ -0,0 +1,205 @@
"""
Nominatim Client for geocoding and address search
"""
import structlog
import httpx
from typing import Optional, List, Dict, Any
from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class NominatimClient:
"""
Client for Nominatim geocoding service.
Provides address search and geocoding capabilities for the bakery onboarding flow.
"""
def __init__(self, config: BaseServiceSettings):
self.config = config
self.nominatim_url = getattr(
config,
"NOMINATIM_SERVICE_URL",
"http://nominatim-service:8080"
)
self.timeout = 30
async def search_address(
self,
query: str,
country_codes: str = "es",
limit: int = 5,
addressdetails: bool = True
) -> List[Dict[str, Any]]:
"""
Search for addresses matching a query.
Args:
query: Address search query (e.g., "Calle Mayor 1, Madrid")
country_codes: Limit search to country codes (default: "es" for Spain)
limit: Maximum number of results (default: 5)
addressdetails: Include detailed address breakdown (default: True)
Returns:
List of geocoded results with lat, lon, and address details
Example:
results = await nominatim.search_address("Calle Mayor 1, Madrid")
if results:
lat = results[0]["lat"]
lon = results[0]["lon"]
display_name = results[0]["display_name"]
"""
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.nominatim_url}/search",
params={
"q": query,
"format": "json",
"countrycodes": country_codes,
"addressdetails": 1 if addressdetails else 0,
"limit": limit
}
)
if response.status_code == 200:
results = response.json()
logger.info(
"Address search completed",
query=query,
results_count=len(results)
)
return results
else:
logger.error(
"Nominatim search failed",
query=query,
status_code=response.status_code,
response=response.text
)
return []
except httpx.TimeoutException:
logger.error("Nominatim search timeout", query=query)
return []
except Exception as e:
logger.error("Nominatim search error", query=query, error=str(e))
return []
async def geocode_address(
self,
street: str,
city: str,
postal_code: Optional[str] = None,
country: str = "Spain"
) -> Optional[Dict[str, Any]]:
"""
Geocode a structured address to coordinates.
Args:
street: Street name and number
city: City name
postal_code: Optional postal code
country: Country name (default: "Spain")
Returns:
Dict with lat, lon, and display_name, or None if not found
Example:
location = await nominatim.geocode_address(
street="Calle Mayor 1",
city="Madrid",
postal_code="28013"
)
if location:
lat, lon = location["lat"], location["lon"]
"""
# Build structured query
query_parts = [street, city]
if postal_code:
query_parts.append(postal_code)
query_parts.append(country)
query = ", ".join(query_parts)
results = await self.search_address(query, limit=1)
if results:
return results[0]
return None
async def reverse_geocode(
self,
latitude: float,
longitude: float
) -> Optional[Dict[str, Any]]:
"""
Reverse geocode coordinates to an address.
Args:
latitude: Latitude coordinate
longitude: Longitude coordinate
Returns:
Dict with address details, or None if not found
Example:
address = await nominatim.reverse_geocode(40.4168, -3.7038)
if address:
city = address["address"]["city"]
street = address["address"]["road"]
"""
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.nominatim_url}/reverse",
params={
"lat": latitude,
"lon": longitude,
"format": "json",
"addressdetails": 1
}
)
if response.status_code == 200:
result = response.json()
logger.info(
"Reverse geocoding completed",
lat=latitude,
lon=longitude
)
return result
else:
logger.error(
"Nominatim reverse geocoding failed",
lat=latitude,
lon=longitude,
status_code=response.status_code
)
return None
except Exception as e:
logger.error(
"Reverse geocoding error",
lat=latitude,
lon=longitude,
error=str(e)
)
return None
async def health_check(self) -> bool:
"""
Check if Nominatim service is healthy.
Returns:
True if service is responding, False otherwise
"""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{self.nominatim_url}/status")
return response.status_code == 200
except Exception as e:
logger.warning("Nominatim health check failed", error=str(e))
return False
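
A short usage sketch, assuming BaseServiceSettings can be instantiated with defaults (so NOMINATIM_SERVICE_URL falls back to http://nominatim-service:8080):

import asyncio
from shared.config.base import BaseServiceSettings

async def main():
    nominatim = NominatimClient(BaseServiceSettings())
    # Structured geocoding returns None instead of raising on failure.
    location = await nominatim.geocode_address(
        street="Calle Mayor 1",
        city="Madrid",
        postal_code="28013",
    )
    if location:
        print(location["lat"], location["lon"], location["display_name"])

asyncio.run(main())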

View File

@@ -0,0 +1,179 @@
"""
OpenTelemetry distributed tracing integration
Provides end-to-end request tracking across all services
"""
import structlog
from typing import Optional
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
logger = structlog.get_logger()
def setup_tracing(
app,
service_name: str,
service_version: str = "1.0.0",
jaeger_endpoint: str = "http://jaeger-collector.monitoring:4317"
):
"""
Setup OpenTelemetry distributed tracing for a FastAPI service.
Automatically instruments:
- FastAPI endpoints
- HTTPX client requests (inter-service calls)
- Redis operations
- PostgreSQL/SQLAlchemy queries
Args:
app: FastAPI application instance
service_name: Name of the service (e.g., "auth-service")
service_version: Version of the service
jaeger_endpoint: Jaeger collector gRPC endpoint
Example:
from shared.monitoring.tracing import setup_tracing
app = FastAPI(title="Auth Service")
setup_tracing(app, "auth-service")
"""
try:
# Create resource with service information
resource = Resource(attributes={
SERVICE_NAME: service_name,
SERVICE_VERSION: service_version,
"deployment.environment": "production"
})
# Configure tracer provider
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider)
# Configure OTLP exporter to send to Jaeger
otlp_exporter = OTLPSpanExporter(
endpoint=jaeger_endpoint,
insecure=True # Plaintext gRPC; switch to TLS in production
)
# Add span processor with batching for performance
span_processor = BatchSpanProcessor(otlp_exporter)
tracer_provider.add_span_processor(span_processor)
# Auto-instrument FastAPI
FastAPIInstrumentor.instrument_app(
app,
tracer_provider=tracer_provider,
excluded_urls="health,metrics" # Don't trace health/metrics endpoints
)
# Auto-instrument HTTPX (inter-service communication)
HTTPXClientInstrumentor().instrument(tracer_provider=tracer_provider)
# Auto-instrument Redis
try:
RedisInstrumentor().instrument(tracer_provider=tracer_provider)
except Exception as e:
logger.warning(f"Failed to instrument Redis: {e}")
# Auto-instrument PostgreSQL (psycopg2) - skip if not available
# Most services use asyncpg instead of psycopg2
# try:
# Psycopg2Instrumentor().instrument(tracer_provider=tracer_provider)
# except Exception as e:
# logger.warning(f"Failed to instrument Psycopg2: {e}")
# Auto-instrument SQLAlchemy
try:
SQLAlchemyInstrumentor().instrument(tracer_provider=tracer_provider)
except Exception as e:
logger.warning(f"Failed to instrument SQLAlchemy: {e}")
logger.info(
"Distributed tracing configured",
service=service_name,
jaeger_endpoint=jaeger_endpoint
)
except Exception as e:
logger.error(
"Failed to setup tracing - continuing without it",
service=service_name,
error=str(e)
)
def get_current_trace_id() -> Optional[str]:
"""
Get the current trace ID for correlation with logs.
Returns:
Trace ID as hex string, or None if no active trace
"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
return format(span.get_span_context().trace_id, '032x')
return None
def get_current_span_id() -> Optional[str]:
"""
Get the current span ID.
Returns:
Span ID as hex string, or None if no active span
"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
return format(span.get_span_context().span_id, '016x')
return None
def add_trace_attributes(**attributes):
"""
Add custom attributes to the current span.
Example:
add_trace_attributes(
user_id="123",
tenant_id="abc",
operation="user_registration"
)
"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
for key, value in attributes.items():
span.set_attribute(key, str(value))
def add_trace_event(name: str, **attributes):
"""
Add an event to the current span (for important operations).
Example:
add_trace_event("user_authenticated", user_id="123", method="jwt")
"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
span.add_event(name, attributes)
def record_exception(exception: Exception):
"""
Record an exception in the current span.
Args:
exception: The exception to record
"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
span.record_exception(exception)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(exception)))
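
A minimal wiring sketch tying the helpers together; the /login route is hypothetical:

import structlog
from fastapi import FastAPI
from shared.monitoring.tracing import (
    setup_tracing, add_trace_event, get_current_trace_id
)

logger = structlog.get_logger()
app = FastAPI(title="Auth Service")
setup_tracing(app, "auth-service", service_version="1.0.0")

@app.get("/login")
async def login():
    add_trace_event("user_authenticated", method="jwt")
    # Attach the trace ID so log lines can be joined with Jaeger traces.
    logger.info("login_succeeded", trace_id=get_current_trace_id())
    return {"ok": True}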

View File

@@ -0,0 +1,33 @@
"""
Redis utilities for Bakery-IA platform
Provides Redis connection management and rate limiting
"""
from shared.redis_utils.client import (
RedisConnectionManager,
get_redis_manager,
initialize_redis,
get_redis_client,
close_redis,
redis_context,
set_with_ttl,
get_value,
increment_counter,
get_keys_pattern
)
__all__ = [
# Connection management
"RedisConnectionManager",
"get_redis_manager",
"initialize_redis",
"get_redis_client",
"close_redis",
"redis_context",
# Convenience functions
"set_with_ttl",
"get_value",
"increment_counter",
"get_keys_pattern",
]

View File

@@ -0,0 +1,329 @@
"""
Redis client initialization and connection management
Provides standardized Redis connection for all services
"""
import redis.asyncio as redis
from typing import Optional
import structlog
from contextlib import asynccontextmanager
logger = structlog.get_logger()
class RedisConnectionManager:
"""
Manages Redis connections with connection pooling and error handling
Exposed as a module-level singleton so connections are shared across a service
"""
def __init__(self):
self._client: Optional[redis.Redis] = None
self._pool: Optional[redis.ConnectionPool] = None
self.logger = logger
async def initialize(
self,
redis_url: str,
db: int = 0,
max_connections: int = 50,
decode_responses: bool = True,
retry_on_timeout: bool = True,
socket_keepalive: bool = True,
health_check_interval: int = 30
):
"""
Initialize Redis connection with pool
Args:
redis_url: Redis connection URL (redis://[:password]@host:port)
db: Database number (0-15)
max_connections: Maximum connections in pool
decode_responses: Automatically decode responses to strings
retry_on_timeout: Retry on timeout errors
socket_keepalive: Enable TCP keepalive
health_check_interval: Health check interval in seconds
"""
try:
# Create connection pool
self._pool = redis.ConnectionPool.from_url(
redis_url,
db=db,
max_connections=max_connections,
decode_responses=decode_responses,
retry_on_timeout=retry_on_timeout,
socket_keepalive=socket_keepalive,
health_check_interval=health_check_interval
)
# Create Redis client with pool
self._client = redis.Redis(connection_pool=self._pool)
# Test connection
await self._client.ping()
self.logger.info(
"redis_initialized",
redis_url=redis_url.split("@")[-1], # Log only host:port, not password
db=db,
max_connections=max_connections
)
except Exception as e:
self.logger.error(
"redis_initialization_failed",
error=str(e),
redis_url=redis_url.split("@")[-1]
)
raise
async def close(self):
"""Close Redis connection and pool"""
if self._client:
await self._client.close()
self.logger.info("redis_client_closed")
if self._pool:
await self._pool.disconnect()
self.logger.info("redis_pool_closed")
def get_client(self) -> redis.Redis:
"""
Get Redis client instance
Returns:
Redis client
Raises:
RuntimeError: If client not initialized
"""
if self._client is None:
raise RuntimeError("Redis client not initialized. Call initialize() first.")
return self._client
async def health_check(self) -> bool:
"""
Check Redis connection health
Returns:
bool: True if healthy, False otherwise
"""
try:
if self._client is None:
return False
await self._client.ping()
return True
except Exception as e:
self.logger.error("redis_health_check_failed", error=str(e))
return False
async def get_info(self) -> dict:
"""
Get Redis server information
Returns:
dict: Redis INFO command output
"""
try:
if self._client is None:
return {}
return await self._client.info()
except Exception as e:
self.logger.error("redis_info_failed", error=str(e))
return {}
async def flush_db(self):
"""
Flush current database (USE WITH CAUTION)
Only for development/testing
"""
try:
if self._client is None:
raise RuntimeError("Redis client not initialized")
await self._client.flushdb()
self.logger.warning("redis_database_flushed")
except Exception as e:
self.logger.error("redis_flush_failed", error=str(e))
raise
# Global connection manager instance
_redis_manager: Optional[RedisConnectionManager] = None
async def get_redis_manager() -> RedisConnectionManager:
"""
Get or create global Redis manager instance
Returns:
RedisConnectionManager instance
"""
global _redis_manager
if _redis_manager is None:
_redis_manager = RedisConnectionManager()
return _redis_manager
async def initialize_redis(
redis_url: str,
db: int = 0,
max_connections: int = 50,
**kwargs
) -> redis.Redis:
"""
Initialize Redis and return client
Args:
redis_url: Redis connection URL
db: Database number
max_connections: Maximum connections in pool
**kwargs: Additional connection parameters
Returns:
Redis client instance
"""
manager = await get_redis_manager()
await manager.initialize(
redis_url=redis_url,
db=db,
max_connections=max_connections,
**kwargs
)
return manager.get_client()
async def get_redis_client() -> redis.Redis:
"""
Get initialized Redis client
Returns:
Redis client instance
Raises:
RuntimeError: If Redis not initialized
"""
manager = await get_redis_manager()
return manager.get_client()
async def close_redis():
"""Close Redis connections"""
global _redis_manager
if _redis_manager:
await _redis_manager.close()
_redis_manager = None
@asynccontextmanager
async def redis_context(redis_url: str, db: int = 0):
"""
Context manager for Redis connections
Usage:
async with redis_context(settings.REDIS_URL) as client:
await client.set("key", "value")
Args:
redis_url: Redis connection URL
db: Database number
Yields:
Redis client
"""
client = None
try:
client = await initialize_redis(redis_url, db=db)
yield client
finally:
if client:
await close_redis()
# Convenience functions for common operations
async def set_with_ttl(key: str, value: str, ttl: int) -> bool:
"""
Set key with TTL
Args:
key: Redis key
value: Value to set
ttl: Time to live in seconds
Returns:
bool: True if successful
"""
try:
client = await get_redis_client()
await client.setex(key, ttl, value)
return True
except Exception as e:
logger.error("redis_set_failed", key=key, error=str(e))
return False
async def get_value(key: str) -> Optional[str]:
"""
Get value by key
Args:
key: Redis key
Returns:
Value or None if not found
"""
try:
client = await get_redis_client()
return await client.get(key)
except Exception as e:
logger.error("redis_get_failed", key=key, error=str(e))
return None
async def increment_counter(key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
"""
Increment counter with optional TTL
Args:
key: Redis key
amount: Amount to increment
ttl: Time to live in seconds (sets on first increment)
Returns:
New counter value
"""
try:
client = await get_redis_client()
new_value = await client.incrby(key, amount)
# Set TTL if specified and key is new (value == amount)
if ttl and new_value == amount:
await client.expire(key, ttl)
return new_value
except Exception as e:
logger.error("redis_increment_failed", key=key, error=str(e))
return 0
async def get_keys_pattern(pattern: str) -> list:
"""
Get keys matching pattern
Args:
pattern: Redis key pattern (e.g., "quota:*")
Returns:
List of matching keys
"""
try:
client = await get_redis_client()
return await client.keys(pattern)
except Exception as e:
logger.error("redis_keys_failed", pattern=pattern, error=str(e))
return []
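
A self-contained usage sketch of the helpers above, assuming a local Redis at redis://localhost:6379 (services would pass settings.REDIS_URL instead):

import asyncio
from shared.redis_utils import (
    initialize_redis, close_redis, set_with_ttl, get_value, increment_counter
)

async def main():
    await initialize_redis("redis://localhost:6379", db=0)
    try:
        await set_with_ttl("session:abc", "user-123", ttl=3600)
        print(await get_value("session:abc"))
        # TTL is applied only on the first increment of a fresh key.
        print(await increment_counter("hits:tenant-1", ttl=86400))
    finally:
        await close_redis()

asyncio.run(main())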

View File

@@ -0,0 +1,9 @@
# OpenTelemetry dependencies for distributed tracing
opentelemetry-api==1.21.0
opentelemetry-sdk==1.21.0
opentelemetry-instrumentation-fastapi==0.42b0
opentelemetry-instrumentation-httpx==0.42b0
opentelemetry-instrumentation-redis==0.42b0
# opentelemetry-instrumentation-psycopg2==0.42b0 # Commented out - not all services use psycopg2
opentelemetry-instrumentation-sqlalchemy==0.42b0
opentelemetry-exporter-otlp-proto-grpc==1.21.0

View File

@@ -0,0 +1,31 @@
"""
Security utilities for RBAC, audit logging, and rate limiting
"""
from shared.security.audit_logger import (
AuditLogger,
AuditSeverity,
AuditAction,
create_audit_logger,
create_audit_log_model
)
from shared.security.rate_limiter import (
RateLimiter,
QuotaType,
create_rate_limiter
)
__all__ = [
# Audit logging
"AuditLogger",
"AuditSeverity",
"AuditAction",
"create_audit_logger",
"create_audit_log_model",
# Rate limiting
"RateLimiter",
"QuotaType",
"create_rate_limiter",
]

View File

@@ -0,0 +1,317 @@
"""
Audit logging system for tracking critical operations across all services
"""
import uuid
from datetime import datetime, timezone
from typing import Optional, Dict, Any
from enum import Enum
import structlog
from sqlalchemy import Column, String, DateTime, Text, Index
from sqlalchemy.dialects.postgresql import UUID, JSON
logger = structlog.get_logger()
class AuditSeverity(str, Enum):
"""Severity levels for audit events"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class AuditAction(str, Enum):
"""Common audit action types"""
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
APPROVE = "approve"
REJECT = "reject"
CANCEL = "cancel"
EXPORT = "export"
IMPORT = "import"
INVITE = "invite"
REMOVE = "remove"
UPGRADE = "upgrade"
DOWNGRADE = "downgrade"
DEACTIVATE = "deactivate"
ACTIVATE = "activate"
def create_audit_log_model(Base):
"""
Factory function to create AuditLog model for any service
Each service has its own audit_logs table in their database
Usage in service models/__init__.py:
from shared.database.base import Base
from shared.security import create_audit_log_model
AuditLog = create_audit_log_model(Base)
Args:
Base: SQLAlchemy declarative base for the service
Returns:
AuditLog model class bound to the service's Base
"""
class AuditLog(Base):
"""
Audit log model for tracking critical operations
Each service has its own audit_logs table for data locality
"""
__tablename__ = "audit_logs"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Tenant and user context
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Action details
action = Column(String(100), nullable=False, index=True) # create, update, delete, etc.
resource_type = Column(String(100), nullable=False, index=True) # supplier, recipe, order, etc.
resource_id = Column(String(255), nullable=True, index=True)
# Severity and categorization
severity = Column(
String(20),
nullable=False,
default="medium",
index=True
) # low, medium, high, critical
# Service identification
service_name = Column(String(100), nullable=False, index=True)
# Details
description = Column(Text, nullable=True)
# Audit trail data
changes = Column(JSON, nullable=True) # Before/after values for updates
audit_metadata = Column(JSON, nullable=True) # Additional context
# Request context
ip_address = Column(String(45), nullable=True) # IPv4 or IPv6
user_agent = Column(Text, nullable=True)
endpoint = Column(String(255), nullable=True)
method = Column(String(10), nullable=True) # GET, POST, PUT, DELETE
# Timestamps
created_at = Column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
index=True
)
# Composite indexes for common query patterns
__table_args__ = (
Index('idx_audit_tenant_created', 'tenant_id', 'created_at'),
Index('idx_audit_user_created', 'user_id', 'created_at'),
Index('idx_audit_resource_type_action', 'resource_type', 'action'),
Index('idx_audit_severity_created', 'severity', 'created_at'),
Index('idx_audit_service_created', 'service_name', 'created_at'),
)
def __repr__(self):
return (
f"<AuditLog(id={self.id}, tenant={self.tenant_id}, "
f"action={self.action}, resource={self.resource_type}, "
f"severity={self.severity})>"
)
def to_dict(self):
"""Convert audit log to dictionary"""
return {
"id": str(self.id),
"tenant_id": str(self.tenant_id),
"user_id": str(self.user_id),
"action": self.action,
"resource_type": self.resource_type,
"resource_id": self.resource_id,
"severity": self.severity,
"service_name": self.service_name,
"description": self.description,
"changes": self.changes,
"metadata": self.audit_metadata,
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"endpoint": self.endpoint,
"method": self.method,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
return AuditLog
class AuditLogger:
"""Service for logging audit events"""
def __init__(self, service_name: str, audit_log_model):
self.service_name = service_name
self.audit_log_model = audit_log_model # AuditLog class from create_audit_log_model()
self.logger = logger.bind(service=service_name)
async def log_event(
self,
db_session,
tenant_id: str,
user_id: str,
action: str,
resource_type: str,
resource_id: Optional[str] = None,
severity: str = "medium",
description: Optional[str] = None,
changes: Optional[Dict[str, Any]] = None,
audit_metadata: Optional[Dict[str, Any]] = None,
endpoint: Optional[str] = None,
method: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
):
"""
Log an audit event
Args:
db_session: Database session
tenant_id: Tenant ID
user_id: User ID who performed the action
action: Action performed (create, update, delete, etc.)
resource_type: Type of resource (user, sale, recipe, etc.)
resource_id: ID of the resource affected
severity: Severity level (low, medium, high, critical)
description: Human-readable description
changes: Dictionary of before/after values for updates
audit_metadata: Additional context
endpoint: API endpoint
method: HTTP method
ip_address: Client IP address
user_agent: Client user agent
"""
try:
audit_log = self.audit_log_model(
tenant_id=uuid.UUID(tenant_id) if isinstance(tenant_id, str) else tenant_id,
user_id=uuid.UUID(user_id) if isinstance(user_id, str) else user_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
severity=severity,
service_name=self.service_name,
description=description,
changes=changes,
audit_metadata=audit_metadata,
endpoint=endpoint,
method=method,
ip_address=ip_address,
user_agent=user_agent,
)
db_session.add(audit_log)
await db_session.commit()
self.logger.info(
"audit_event_logged",
tenant_id=str(tenant_id),
user_id=str(user_id),
action=action,
resource_type=resource_type,
resource_id=resource_id,
severity=severity,
)
except Exception as e:
self.logger.error(
"audit_log_failed",
error=str(e),
tenant_id=str(tenant_id),
user_id=str(user_id),
action=action,
)
# Don't raise - audit logging should not block operations
async def log_deletion(
self,
db_session,
tenant_id: str,
user_id: str,
resource_type: str,
resource_id: str,
resource_data: Optional[Dict[str, Any]] = None,
**kwargs
):
"""Convenience method for logging deletions"""
return await self.log_event(
db_session=db_session,
tenant_id=tenant_id,
user_id=user_id,
action=AuditAction.DELETE.value,
resource_type=resource_type,
resource_id=resource_id,
severity=AuditSeverity.HIGH.value,
description=f"Deleted {resource_type} {resource_id}",
audit_metadata={"deleted_data": resource_data} if resource_data else None,
**kwargs
)
async def log_role_change(
self,
db_session,
tenant_id: str,
user_id: str,
target_user_id: str,
old_role: str,
new_role: str,
**kwargs
):
"""Convenience method for logging role changes"""
return await self.log_event(
db_session=db_session,
tenant_id=tenant_id,
user_id=user_id,
action=AuditAction.UPDATE.value,
resource_type="user_role",
resource_id=target_user_id,
severity=AuditSeverity.HIGH.value,
description=f"Changed user role from {old_role} to {new_role}",
changes={
"before": {"role": old_role},
"after": {"role": new_role}
},
**kwargs
)
async def log_subscription_change(
self,
db_session,
tenant_id: str,
user_id: str,
action: str,
old_plan: Optional[str] = None,
new_plan: Optional[str] = None,
**kwargs
):
"""Convenience method for logging subscription changes"""
return await self.log_event(
db_session=db_session,
tenant_id=tenant_id,
user_id=user_id,
action=action,
resource_type="subscription",
resource_id=tenant_id,
severity=AuditSeverity.CRITICAL.value,
description=f"Subscription {action}: {old_plan} -> {new_plan}" if old_plan else f"Subscription {action}: {new_plan}",
changes={
"before": {"plan": old_plan} if old_plan else None,
"after": {"plan": new_plan} if new_plan else None
},
**kwargs
)
def create_audit_logger(service_name: str, audit_log_model) -> AuditLogger:
"""Factory function to create an audit logger bound to a service's AuditLog model"""
return AuditLogger(service_name, audit_log_model)
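
A usage sketch combining the factory and the logger; the supplier object and the inventory-service name are hypothetical:

from shared.database.base import Base
from shared.security import create_audit_log_model, create_audit_logger

AuditLog = create_audit_log_model(Base)
audit = create_audit_logger("inventory-service", AuditLog)

async def delete_supplier(db_session, tenant_id: str, user_id: str, supplier):
    await db_session.delete(supplier)
    # Best effort: log_event swallows its own errors, so a failed audit
    # write never blocks the deletion itself.
    await audit.log_deletion(
        db_session=db_session,
        tenant_id=tenant_id,
        user_id=user_id,
        resource_type="supplier",
        resource_id=str(supplier.id),
        resource_data={"name": supplier.name},
    )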

View File

@@ -0,0 +1,388 @@
"""
Rate limiting and quota management system for subscription-based features
"""
import time
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from enum import Enum
import structlog
from fastapi import HTTPException, status
logger = structlog.get_logger()
class QuotaType(str, Enum):
"""Types of quotas"""
FORECAST_GENERATION = "forecast_generation"
TRAINING_JOBS = "training_jobs"
BULK_IMPORTS = "bulk_imports"
POS_SYNC = "pos_sync"
API_CALLS = "api_calls"
DEMO_SESSIONS = "demo_sessions"
class RateLimiter:
"""
Redis-based rate limiter for subscription tier quotas
"""
def __init__(self, redis_client):
"""
Initialize rate limiter
Args:
redis_client: Redis client for storing quota counters
"""
self.redis = redis_client
self.logger = logger
def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
"""Generate Redis key for quota tracking"""
date_str = datetime.utcnow().strftime("%Y-%m-%d")
return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}"
def _get_dataset_size_key(self, tenant_id: str) -> str:
"""Generate Redis key for dataset size tracking"""
return f"dataset_size:{tenant_id}"
async def check_and_increment_quota(
self,
tenant_id: str,
quota_type: str,
limit: Optional[int],
period: int = 86400 # 24 hours in seconds
) -> Dict[str, Any]:
"""
Check if quota allows action and increment counter
Args:
tenant_id: Tenant ID
quota_type: Type of quota to check
limit: Maximum allowed count (None = unlimited)
period: Time period in seconds (default: 24 hours)
Returns:
Dict with:
- allowed: bool
- current: int (current count)
- limit: Optional[int]
- reset_at: datetime (when quota resets)
Raises:
HTTPException: If quota is exceeded
"""
if limit is None:
# Unlimited quota
return {
"allowed": True,
"current": 0,
"limit": None,
"reset_at": None
}
key = self._get_quota_key(tenant_id, quota_type)
try:
# Get current count
current = await self.redis.get(key)
current_count = int(current) if current else 0
# Check if limit exceeded
if current_count >= limit:
ttl = await self.redis.ttl(key)
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
self.logger.warning(
"quota_exceeded",
tenant_id=tenant_id,
quota_type=quota_type,
current=current_count,
limit=limit,
reset_at=reset_at.isoformat()
)
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail={
"error": "quota_exceeded",
"message": f"Daily quota exceeded for {quota_type}",
"current": current_count,
"limit": limit,
"reset_at": reset_at.isoformat(),
"quota_type": quota_type
}
)
# Increment counter
pipe = self.redis.pipeline()
pipe.incr(key)
pipe.expire(key, period)
await pipe.execute()
new_count = current_count + 1
ttl = await self.redis.ttl(key)
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
self.logger.info(
"quota_incremented",
tenant_id=tenant_id,
quota_type=quota_type,
current=new_count,
limit=limit
)
return {
"allowed": True,
"current": new_count,
"limit": limit,
"reset_at": reset_at
}
except HTTPException:
raise
except Exception as e:
self.logger.error(
"quota_check_failed",
error=str(e),
tenant_id=tenant_id,
quota_type=quota_type
)
# Fail open - allow the operation
return {
"allowed": True,
"current": 0,
"limit": limit,
"reset_at": None
}
async def get_current_usage(
self,
tenant_id: str,
quota_type: str
) -> Dict[str, Any]:
"""
Get current quota usage without incrementing
Args:
tenant_id: Tenant ID
quota_type: Type of quota to check
Returns:
Dict with current usage information
"""
key = self._get_quota_key(tenant_id, quota_type)
try:
current = await self.redis.get(key)
current_count = int(current) if current else 0
ttl = await self.redis.ttl(key)
reset_at = datetime.utcnow() + timedelta(seconds=ttl) if ttl > 0 else None
return {
"current": current_count,
"reset_at": reset_at
}
except Exception as e:
self.logger.error(
"usage_check_failed",
error=str(e),
tenant_id=tenant_id,
quota_type=quota_type
)
return {
"current": 0,
"reset_at": None
}
async def reset_quota(self, tenant_id: str, quota_type: str):
"""
Reset quota for a tenant (admin function)
Args:
tenant_id: Tenant ID
quota_type: Type of quota to reset
"""
key = self._get_quota_key(tenant_id, quota_type)
try:
await self.redis.delete(key)
self.logger.info(
"quota_reset",
tenant_id=tenant_id,
quota_type=quota_type
)
except Exception as e:
self.logger.error(
"quota_reset_failed",
error=str(e),
tenant_id=tenant_id,
quota_type=quota_type
)
async def validate_dataset_size(
self,
tenant_id: str,
dataset_rows: int,
subscription_tier: str
):
"""
Validate dataset size against subscription tier limits
Args:
tenant_id: Tenant ID
dataset_rows: Number of rows in dataset
subscription_tier: User's subscription tier
Raises:
HTTPException: If dataset size exceeds limit
"""
# Dataset size limits per tier
dataset_limits = {
'starter': 1000,
'professional': 10000,
'enterprise': None # Unlimited
}
limit = dataset_limits.get(subscription_tier.lower())
if limit is not None and dataset_rows > limit:
self.logger.warning(
"dataset_size_exceeded",
tenant_id=tenant_id,
dataset_rows=dataset_rows,
limit=limit,
tier=subscription_tier
)
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail={
"error": "dataset_size_limit_exceeded",
"message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
"current_size": dataset_rows,
"limit": limit,
"tier": subscription_tier,
"upgrade_url": "/app/settings/profile"
}
)
self.logger.info(
"dataset_size_validated",
tenant_id=tenant_id,
dataset_rows=dataset_rows,
tier=subscription_tier
)
async def validate_forecast_horizon(
self,
tenant_id: str,
horizon_days: int,
subscription_tier: str
):
"""
Validate forecast horizon against subscription tier limits
Args:
tenant_id: Tenant ID
horizon_days: Number of days to forecast
subscription_tier: User's subscription tier
Raises:
HTTPException: If horizon exceeds limit
"""
# Forecast horizon limits per tier
horizon_limits = {
'starter': 7,
'professional': 90,
'enterprise': 365 # Practically unlimited
}
limit = horizon_limits.get(subscription_tier.lower(), 7)
if horizon_days > limit:
self.logger.warning(
"forecast_horizon_exceeded",
tenant_id=tenant_id,
horizon_days=horizon_days,
limit=limit,
tier=subscription_tier
)
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail={
"error": "forecast_horizon_limit_exceeded",
"message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
"requested_horizon": horizon_days,
"limit": limit,
"tier": subscription_tier,
"upgrade_url": "/app/settings/profile"
}
)
self.logger.info(
"forecast_horizon_validated",
tenant_id=tenant_id,
horizon_days=horizon_days,
tier=subscription_tier
)
async def validate_historical_data_access(
self,
tenant_id: str,
days_back: int,
subscription_tier: str
):
"""
Validate historical data access against subscription tier limits
Args:
tenant_id: Tenant ID
days_back: Number of days of historical data requested
subscription_tier: User's subscription tier
Raises:
HTTPException: If historical data access exceeds limit
"""
# Historical data limits per tier
history_limits = {
'starter': 7,
'professional': 90,
'enterprise': None # Unlimited
}
limit = history_limits.get(subscription_tier.lower(), 7)
if limit is not None and days_back > limit:
self.logger.warning(
"historical_data_limit_exceeded",
tenant_id=tenant_id,
days_back=days_back,
limit=limit,
tier=subscription_tier
)
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail={
"error": "historical_data_limit_exceeded",
"message": f"Historical data limited to {limit} days for {subscription_tier} tier",
"requested_days": days_back,
"limit": limit,
"tier": subscription_tier,
"upgrade_url": "/app/settings/profile"
}
)
self.logger.info(
"historical_data_access_validated",
tenant_id=tenant_id,
days_back=days_back,
tier=subscription_tier
)
def create_rate_limiter(redis_client) -> RateLimiter:
"""Factory function to create rate limiter"""
return RateLimiter(redis_client)
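
A sketch of enforcing the forecast quota inside a FastAPI endpoint; the route and the hard-coded starter limit of 10/day are illustrative:

from fastapi import APIRouter
from shared.redis_utils import get_redis_client
from shared.security import create_rate_limiter, QuotaType

router = APIRouter()

@router.post("/forecast")
async def generate_forecast(tenant_id: str):
    limiter = create_rate_limiter(await get_redis_client())
    # Raises HTTP 429 with reset_at details once the daily quota is spent;
    # on Redis errors it fails open and allows the request.
    quota = await limiter.check_and_increment_quota(
        tenant_id=tenant_id,
        quota_type=QuotaType.FORECAST_GENERATION.value,
        limit=10,  # starter tier's FORECAST_GENERATION_PER_DAY
    )
    return {"used": quota["current"], "limit": quota["limit"]}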

View File

@@ -23,6 +23,7 @@ from fastapi.routing import APIRouter
from shared.monitoring import setup_logging
from shared.monitoring.metrics import setup_metrics_early
from shared.monitoring.health_checks import setup_fastapi_health_checks
from shared.monitoring.tracing import setup_tracing
from shared.database.base import DatabaseManager
if TYPE_CHECKING:
@@ -51,6 +52,7 @@ class BaseFastAPIService:
enable_cors: bool = True,
enable_exception_handlers: bool = True,
enable_messaging: bool = False,
enable_tracing: bool = True,
custom_metrics: Optional[Dict[str, Dict[str, Any]]] = None,
alert_service_class: Optional[type] = None
):
@@ -69,6 +71,7 @@ class BaseFastAPIService:
self.enable_cors = enable_cors
self.enable_exception_handlers = enable_exception_handlers
self.enable_messaging = enable_messaging
self.enable_tracing = enable_tracing
self.custom_metrics = custom_metrics or {}
self.alert_service_class = alert_service_class
@@ -106,6 +109,18 @@ class BaseFastAPIService:
if self.enable_metrics:
self.metrics_collector = setup_metrics_early(self.app, self.service_name)
# Setup distributed tracing
if self.enable_tracing:
try:
jaeger_endpoint = os.getenv(
"JAEGER_COLLECTOR_ENDPOINT",
"http://jaeger-collector.monitoring:4317"
)
setup_tracing(self.app, self.service_name, self.version, jaeger_endpoint)
self.logger.info(f"Distributed tracing enabled for {self.service_name}")
except Exception as e:
self.logger.warning(f"Failed to setup tracing, continuing without it: {e}")
# Setup lifespan
self.app.router.lifespan_context = self._create_lifespan()
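
A hedged sketch of opting in from a service entrypoint, assuming the constructor accepts service_name and that the FastAPI app is exposed as service.app (both implied by the fields above):

import os

# Default endpoint matches the fallback in the hunk above.
os.environ.setdefault(
    "JAEGER_COLLECTOR_ENDPOINT", "http://jaeger-collector.monitoring:4317"
)

service = BaseFastAPIService(
    service_name="orders-service",  # hypothetical service name
    enable_tracing=True,            # the new default; False skips Jaeger wiring
)
app = service.app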

View File

@@ -0,0 +1,486 @@
"""
Centralized Subscription Plan Configuration
Owner: Tenant Service
Single source of truth for all subscription tiers, quotas, features, and limits
"""
from typing import Optional, Dict, Any, List
from enum import Enum
from decimal import Decimal
class SubscriptionTier(str, Enum):
"""Subscription tier enumeration"""
STARTER = "starter"
PROFESSIONAL = "professional"
ENTERPRISE = "enterprise"
class BillingCycle(str, Enum):
"""Billing cycle options"""
MONTHLY = "monthly"
YEARLY = "yearly"
# ============================================================================
# PRICING CONFIGURATION
# ============================================================================
class PlanPricing:
"""Pricing for each subscription tier"""
MONTHLY_PRICES = {
SubscriptionTier.STARTER: Decimal("49.00"),
SubscriptionTier.PROFESSIONAL: Decimal("149.00"),
SubscriptionTier.ENTERPRISE: Decimal("499.00"), # Base price, custom quotes available
}
YEARLY_PRICES = {
SubscriptionTier.STARTER: Decimal("490.00"), # ~17% discount (2 months free)
SubscriptionTier.PROFESSIONAL: Decimal("1490.00"), # ~17% discount
SubscriptionTier.ENTERPRISE: Decimal("4990.00"), # Base price, custom quotes available
}
@staticmethod
def get_price(tier: str, billing_cycle: str = "monthly") -> Decimal:
"""Get price for tier and billing cycle"""
tier_enum = SubscriptionTier(tier.lower())
if billing_cycle == "yearly":
return PlanPricing.YEARLY_PRICES[tier_enum]
return PlanPricing.MONTHLY_PRICES[tier_enum]
# ============================================================================
# QUOTA LIMITS CONFIGURATION
# ============================================================================
class QuotaLimits:
"""
Resource quotas and limits for each subscription tier
None = Unlimited
"""
# ===== Team & Organization Limits =====
MAX_USERS = {
SubscriptionTier.STARTER: 5,
SubscriptionTier.PROFESSIONAL: 20,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
MAX_LOCATIONS = {
SubscriptionTier.STARTER: 1,
SubscriptionTier.PROFESSIONAL: 3,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== Product & Inventory Limits =====
MAX_PRODUCTS = {
SubscriptionTier.STARTER: 50,
SubscriptionTier.PROFESSIONAL: 500,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
MAX_RECIPES = {
SubscriptionTier.STARTER: 25,
SubscriptionTier.PROFESSIONAL: 250,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
MAX_SUPPLIERS = {
SubscriptionTier.STARTER: 10,
SubscriptionTier.PROFESSIONAL: 100,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== ML & Analytics Quotas (Daily Limits) =====
TRAINING_JOBS_PER_DAY = {
SubscriptionTier.STARTER: 1,
SubscriptionTier.PROFESSIONAL: 5,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
FORECAST_GENERATION_PER_DAY = {
SubscriptionTier.STARTER: 10,
SubscriptionTier.PROFESSIONAL: 100,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== Data Limits =====
DATASET_SIZE_ROWS = {
SubscriptionTier.STARTER: 1000,
SubscriptionTier.PROFESSIONAL: 10000,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
FORECAST_HORIZON_DAYS = {
SubscriptionTier.STARTER: 7,
SubscriptionTier.PROFESSIONAL: 90,
SubscriptionTier.ENTERPRISE: 365,
}
HISTORICAL_DATA_ACCESS_DAYS = {
SubscriptionTier.STARTER: 30, # 1 month
SubscriptionTier.PROFESSIONAL: 365, # 1 year
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== Import/Export Limits =====
BULK_IMPORT_ROWS = {
SubscriptionTier.STARTER: 100,
SubscriptionTier.PROFESSIONAL: 1000,
SubscriptionTier.ENTERPRISE: 10000,
}
BULK_EXPORT_ROWS = {
SubscriptionTier.STARTER: 1000,
SubscriptionTier.PROFESSIONAL: 10000,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== Integration Limits =====
POS_SYNC_INTERVAL_MINUTES = {
SubscriptionTier.STARTER: 60, # Hourly
SubscriptionTier.PROFESSIONAL: 15, # Every 15 minutes
SubscriptionTier.ENTERPRISE: 5, # Every 5 minutes (near real-time)
}
API_CALLS_PER_HOUR = {
SubscriptionTier.STARTER: 100,
SubscriptionTier.PROFESSIONAL: 1000,
SubscriptionTier.ENTERPRISE: 10000,
}
WEBHOOK_ENDPOINTS = {
SubscriptionTier.STARTER: 2,
SubscriptionTier.PROFESSIONAL: 10,
SubscriptionTier.ENTERPRISE: None, # Unlimited
}
# ===== Storage Limits =====
FILE_STORAGE_GB = {
SubscriptionTier.STARTER: 1,
SubscriptionTier.PROFESSIONAL: 10,
SubscriptionTier.ENTERPRISE: 100,
}
REPORT_RETENTION_DAYS = {
SubscriptionTier.STARTER: 30,
SubscriptionTier.PROFESSIONAL: 180,
SubscriptionTier.ENTERPRISE: 365,
}
@staticmethod
def get_limit(quota_type: str, tier: str) -> Optional[int]:
"""
Get quota limit for a specific type and tier
Args:
quota_type: Quota type (e.g., 'MAX_USERS')
tier: Subscription tier
Returns:
Optional[int]: Limit value or None for unlimited
"""
tier_enum = SubscriptionTier(tier.lower())
quota_map = {
'MAX_USERS': QuotaLimits.MAX_USERS,
'MAX_LOCATIONS': QuotaLimits.MAX_LOCATIONS,
'MAX_PRODUCTS': QuotaLimits.MAX_PRODUCTS,
'MAX_RECIPES': QuotaLimits.MAX_RECIPES,
'MAX_SUPPLIERS': QuotaLimits.MAX_SUPPLIERS,
'TRAINING_JOBS_PER_DAY': QuotaLimits.TRAINING_JOBS_PER_DAY,
'FORECAST_GENERATION_PER_DAY': QuotaLimits.FORECAST_GENERATION_PER_DAY,
'DATASET_SIZE_ROWS': QuotaLimits.DATASET_SIZE_ROWS,
'FORECAST_HORIZON_DAYS': QuotaLimits.FORECAST_HORIZON_DAYS,
'HISTORICAL_DATA_ACCESS_DAYS': QuotaLimits.HISTORICAL_DATA_ACCESS_DAYS,
'BULK_IMPORT_ROWS': QuotaLimits.BULK_IMPORT_ROWS,
'BULK_EXPORT_ROWS': QuotaLimits.BULK_EXPORT_ROWS,
'POS_SYNC_INTERVAL_MINUTES': QuotaLimits.POS_SYNC_INTERVAL_MINUTES,
'API_CALLS_PER_HOUR': QuotaLimits.API_CALLS_PER_HOUR,
'WEBHOOK_ENDPOINTS': QuotaLimits.WEBHOOK_ENDPOINTS,
'FILE_STORAGE_GB': QuotaLimits.FILE_STORAGE_GB,
'REPORT_RETENTION_DAYS': QuotaLimits.REPORT_RETENTION_DAYS,
}
quotas = quota_map.get(quota_type, {})
return quotas.get(tier_enum)
# ============================================================================
# FEATURE ACCESS CONFIGURATION
# ============================================================================
class PlanFeatures:
"""
Feature availability by subscription tier
Each tier includes all features from lower tiers
"""
# ===== Core Features (All Tiers) =====
CORE_FEATURES = [
'inventory_management',
'sales_tracking',
'basic_recipes',
'production_planning',
'basic_reporting',
'mobile_app_access',
'email_support',
'easy_step_by_step_onboarding', # NEW: Value-add onboarding
]
# ===== Starter Tier Features =====
STARTER_FEATURES = CORE_FEATURES + [
'basic_forecasting',
'demand_prediction',
'waste_tracking',
'order_management',
'customer_management',
'supplier_management',
'batch_tracking',
'expiry_alerts',
]
# ===== Professional Tier Features =====
PROFESSIONAL_FEATURES = STARTER_FEATURES + [
# Advanced Analytics
'advanced_analytics',
'custom_reports',
'sales_analytics',
'supplier_performance',
'waste_analysis',
'profitability_analysis',
# External Data Integration
'weather_data_integration',
'traffic_data_integration',
# Multi-location
'multi_location_support',
'location_comparison',
'inventory_transfer',
# Advanced Forecasting
'batch_scaling',
'recipe_feasibility_check',
'seasonal_patterns',
'longer_forecast_horizon',
# Integration
'pos_integration',
'accounting_export',
'basic_api_access',
# Support
'priority_email_support',
'phone_support',
]
# ===== Enterprise Tier Features =====
ENTERPRISE_FEATURES = PROFESSIONAL_FEATURES + [
# Advanced ML & AI
'scenario_modeling',
'what_if_analysis',
'risk_assessment',
'advanced_ml_parameters',
'model_artifacts_access',
'custom_algorithms',
# Advanced Integration
'full_api_access',
'unlimited_webhooks',
'erp_integration',
'custom_integrations',
# Enterprise Features
'multi_tenant_management',
'white_label_option',
'custom_branding',
'sso_saml',
'advanced_permissions',
'audit_logs_export',
'compliance_reports',
# Advanced Analytics
'benchmarking',
'competitive_analysis',
'market_insights',
'predictive_maintenance',
# Premium Support
'dedicated_account_manager',
'priority_support',
'24_7_support',
'custom_training',
'onsite_support', # Optional add-on
]
@staticmethod
def get_features(tier: str) -> List[str]:
"""Get all features for a tier"""
tier_enum = SubscriptionTier(tier.lower())
feature_map = {
SubscriptionTier.STARTER: PlanFeatures.STARTER_FEATURES,
SubscriptionTier.PROFESSIONAL: PlanFeatures.PROFESSIONAL_FEATURES,
SubscriptionTier.ENTERPRISE: PlanFeatures.ENTERPRISE_FEATURES,
}
return feature_map.get(tier_enum, PlanFeatures.CORE_FEATURES)
@staticmethod
def has_feature(tier: str, feature: str) -> bool:
"""Check if a tier has access to a feature"""
features = PlanFeatures.get_features(tier)
return feature in features
@staticmethod
def requires_professional_tier(feature: str) -> bool:
"""Check if feature requires Professional+ tier"""
return (
feature not in PlanFeatures.STARTER_FEATURES and
feature in PlanFeatures.PROFESSIONAL_FEATURES
)
@staticmethod
def requires_enterprise_tier(feature: str) -> bool:
"""Check if feature requires Enterprise tier"""
return (
feature not in PlanFeatures.PROFESSIONAL_FEATURES and
feature in PlanFeatures.ENTERPRISE_FEATURES
)
# ============================================================================
# SUBSCRIPTION PLAN METADATA
# ============================================================================
class SubscriptionPlanMetadata:
"""Complete metadata for each subscription plan"""
PLANS = {
SubscriptionTier.STARTER: {
"name": "Starter",
"description": "Perfect for small bakeries getting started",
"tagline": "Essential tools for small operations",
"popular": False,
"monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.STARTER],
"yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.STARTER],
"trial_days": 14,
"features": PlanFeatures.STARTER_FEATURES,
"limits": {
"users": QuotaLimits.MAX_USERS[SubscriptionTier.STARTER],
"locations": QuotaLimits.MAX_LOCATIONS[SubscriptionTier.STARTER],
"products": QuotaLimits.MAX_PRODUCTS[SubscriptionTier.STARTER],
"forecasts_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[SubscriptionTier.STARTER],
},
"support": "Email support (48h response)",
"recommended_for": "Single location, up to 5 team members",
},
SubscriptionTier.PROFESSIONAL: {
"name": "Professional",
"description": "For growing bakeries with multiple locations",
"tagline": "Advanced features & analytics",
"popular": True, # Most popular plan
"monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.PROFESSIONAL],
"yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.PROFESSIONAL],
"trial_days": 14,
"features": PlanFeatures.PROFESSIONAL_FEATURES,
"limits": {
"users": QuotaLimits.MAX_USERS[SubscriptionTier.PROFESSIONAL],
"locations": QuotaLimits.MAX_LOCATIONS[SubscriptionTier.PROFESSIONAL],
"products": QuotaLimits.MAX_PRODUCTS[SubscriptionTier.PROFESSIONAL],
"forecasts_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[SubscriptionTier.PROFESSIONAL],
},
"support": "Priority email + phone support (24h response)",
"recommended_for": "Multi-location operations, up to 20 team members",
},
SubscriptionTier.ENTERPRISE: {
"name": "Enterprise",
"description": "For large bakery chains and franchises",
"tagline": "Unlimited scale & custom solutions",
"popular": False,
"monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.ENTERPRISE],
"yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.ENTERPRISE],
"trial_days": 30,
"features": PlanFeatures.ENTERPRISE_FEATURES,
"limits": {
"users": "Unlimited",
"locations": "Unlimited",
"products": "Unlimited",
"forecasts_per_day": "Unlimited",
},
"support": "24/7 dedicated support + account manager",
"recommended_for": "Enterprise operations, unlimited scale",
"custom_pricing": True,
"contact_sales": True,
},
}
@staticmethod
def get_plan_info(tier: str) -> Dict[str, Any]:
"""Get complete plan information"""
tier_enum = SubscriptionTier(tier.lower())
return SubscriptionPlanMetadata.PLANS.get(tier_enum, {})
@staticmethod
def get_all_plans() -> Dict[SubscriptionTier, Dict[str, Any]]:
"""Get information for all plans"""
return SubscriptionPlanMetadata.PLANS
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def get_training_job_quota(tier: str) -> Optional[int]:
"""Get training job daily quota for tier"""
return QuotaLimits.get_limit('TRAINING_JOBS_PER_DAY', tier)
def get_forecast_quota(tier: str) -> Optional[int]:
"""Get forecast generation daily quota for tier"""
return QuotaLimits.get_limit('FORECAST_GENERATION_PER_DAY', tier)
def get_dataset_size_limit(tier: str) -> Optional[int]:
"""Get dataset size limit for tier"""
return QuotaLimits.get_limit('DATASET_SIZE_ROWS', tier)
def get_forecast_horizon_limit(tier: str) -> int:
"""Get forecast horizon limit for tier"""
return QuotaLimits.get_limit('FORECAST_HORIZON_DAYS', tier) or 7
def get_historical_data_limit(tier: str) -> Optional[int]:
"""Get historical data access limit for tier"""
return QuotaLimits.get_limit('HISTORICAL_DATA_ACCESS_DAYS', tier)
def can_access_feature(tier: str, feature: str) -> bool:
"""Check if tier can access a feature"""
return PlanFeatures.has_feature(tier, feature)
def get_tier_comparison() -> Dict[str, Any]:
"""
Get feature comparison across all tiers
Useful for pricing pages
"""
return {
"tiers": ["starter", "professional", "enterprise"],
"features": {
"core": PlanFeatures.CORE_FEATURES,
"starter_only": list(set(PlanFeatures.STARTER_FEATURES) - set(PlanFeatures.CORE_FEATURES)),
"professional_only": list(set(PlanFeatures.PROFESSIONAL_FEATURES) - set(PlanFeatures.STARTER_FEATURES)),
"enterprise_only": list(set(PlanFeatures.ENTERPRISE_FEATURES) - set(PlanFeatures.PROFESSIONAL_FEATURES)),
},
"pricing": {
tier.value: {
"monthly": float(PlanPricing.MONTHLY_PRICES[tier]),
"yearly": float(PlanPricing.YEARLY_PRICES[tier]),
}
for tier in SubscriptionTier
},
}
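
A few checks using the helpers defined above, with values taken from the tables in this file:

tier = "professional"
assert can_access_feature(tier, "pos_integration")
assert get_forecast_quota(tier) == 100
assert get_dataset_size_limit("enterprise") is None  # None = unlimited

price = PlanPricing.get_price(tier, "yearly")  # Decimal("1490.00")
plan = SubscriptionPlanMetadata.get_plan_info(tier)
print(plan["name"], plan["support"])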