Add role-based filtering and improve code
@@ -14,6 +14,7 @@ from urllib.parse import urljoin
from shared.auth.jwt_handler import JWTHandler
from shared.config.base import BaseServiceSettings
from shared.clients.circuit_breaker import CircuitBreaker, CircuitBreakerOpenException

logger = structlog.get_logger()

@@ -91,11 +92,19 @@ class BaseServiceClient(ABC):
        self.config = config
        self.gateway_url = config.GATEWAY_URL
        self.authenticator = ServiceAuthenticator(service_name, config)

        # HTTP client configuration
        self.timeout = config.HTTP_TIMEOUT
        self.retries = config.HTTP_RETRIES
        self.retry_delay = config.HTTP_RETRY_DELAY

        # Circuit breaker for fault tolerance
        self.circuit_breaker = CircuitBreaker(
            service_name=f"{service_name}-client",
            failure_threshold=5,
            timeout=60,
            success_threshold=2
        )

    @abstractmethod
    def get_service_base_path(self) -> str:

@@ -113,8 +122,8 @@ class BaseServiceClient(ABC):
        timeout: Optional[Union[int, httpx.Timeout]] = None
    ) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Make an authenticated request to another service via gateway with circuit breaker protection.

        Args:
            method: HTTP method (GET, POST, PUT, DELETE)
            endpoint: API endpoint (will be prefixed with service base path)
@@ -123,10 +132,53 @@ class BaseServiceClient(ABC):
            params: Query parameters
            headers: Additional headers
            timeout: Request timeout override

        Returns:
            Response data or None if request failed
        """
        try:
            # Wrap the request in the circuit breaker
            return await self.circuit_breaker.call(
                self._do_request,
                method,
                endpoint,
                tenant_id,
                data,
                params,
                headers,
                timeout
            )
        except CircuitBreakerOpenException as e:
            logger.error(
                "Circuit breaker open - request rejected",
                service=self.service_name,
                endpoint=endpoint,
                error=str(e)
            )
            return None
        except Exception as e:
            logger.error(
                "Unexpected error in request",
                service=self.service_name,
                endpoint=endpoint,
                error=str(e)
            )
            return None

    async def _do_request(
        self,
        method: str,
        endpoint: str,
        tenant_id: Optional[str] = None,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[Union[int, httpx.Timeout]] = None
    ) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
        """
        Internal method to execute the HTTP request with retries.
        Called by _make_request through the circuit breaker.
        """
        try:
            # Get service token
            token = await self.authenticator.get_service_token()

@@ -135,7 +187,11 @@ class BaseServiceClient(ABC):
            request_headers = self.authenticator.get_request_headers(tenant_id)
            request_headers["Authorization"] = f"Bearer {token}"
            request_headers["Content-Type"] = "application/json"

            # Propagate request ID for distributed tracing if provided
            if headers and "X-Request-ID" in headers:
                request_headers["X-Request-ID"] = headers["X-Request-ID"]

            if headers:
                request_headers.update(headers)

215 shared/clients/circuit_breaker.py Normal file
@@ -0,0 +1,215 @@
"""
|
||||
Circuit Breaker implementation for inter-service communication
|
||||
Prevents cascading failures by failing fast when a service is unhealthy
|
||||
"""
|
||||
|
||||
import time
|
||||
import structlog
|
||||
from enum import Enum
|
||||
from typing import Callable, Any, Optional
|
||||
import asyncio
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states"""
|
||||
CLOSED = "closed" # Normal operation, requests pass through
|
||||
OPEN = "open" # Service is failing, reject requests immediately
|
||||
HALF_OPEN = "half_open" # Testing if service has recovered
|
||||
|
||||
|
||||
class CircuitBreakerOpenException(Exception):
|
||||
"""Raised when circuit breaker is open and rejects a request"""
|
||||
pass
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker pattern implementation for preventing cascading failures.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, all requests pass through
|
||||
- OPEN: Service is failing, reject all requests immediately
|
||||
- HALF_OPEN: Testing recovery, allow one request through
|
||||
|
||||
Transitions:
|
||||
- CLOSED -> OPEN: After failure_threshold consecutive failures
|
||||
- OPEN -> HALF_OPEN: After timeout seconds have passed
|
||||
- HALF_OPEN -> CLOSED: If test request succeeds
|
||||
- HALF_OPEN -> OPEN: If test request fails
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
failure_threshold: int = 5,
|
||||
timeout: int = 60,
|
||||
success_threshold: int = 2
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
service_name: Name of the service being protected
|
||||
failure_threshold: Number of consecutive failures before opening circuit
|
||||
timeout: Seconds to wait before attempting recovery (half-open state)
|
||||
success_threshold: Consecutive successes needed to close from half-open
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.failure_threshold = failure_threshold
|
||||
self.timeout = timeout
|
||||
self.success_threshold = success_threshold
|
||||
|
||||
self.state = CircuitState.CLOSED
|
||||
self.failure_count = 0
|
||||
self.success_count = 0
|
||||
self.last_failure_time: Optional[float] = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info(
|
||||
"Circuit breaker initialized",
|
||||
service=service_name,
|
||||
failure_threshold=failure_threshold,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
async def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
*args, **kwargs: Arguments to pass to func
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpenException: If circuit is open
|
||||
Exception: Any exception raised by func
|
||||
"""
|
||||
async with self._lock:
|
||||
# Check if circuit should transition to half-open
|
||||
if self.state == CircuitState.OPEN:
|
||||
if self._should_attempt_reset():
|
||||
logger.info(
|
||||
"Circuit breaker transitioning to half-open",
|
||||
service=self.service_name
|
||||
)
|
||||
self.state = CircuitState.HALF_OPEN
|
||||
self.success_count = 0
|
||||
else:
|
||||
# Circuit is open, reject request
|
||||
raise CircuitBreakerOpenException(
|
||||
f"Circuit breaker is OPEN for {self.service_name}. "
|
||||
f"Service will be retried in {self._time_until_retry():.0f} seconds."
|
||||
)
|
||||
|
||||
# Execute function
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
await self._on_success()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await self._on_failure(e)
|
||||
raise
|
||||
|
||||
def _should_attempt_reset(self) -> bool:
|
||||
"""Check if enough time has passed to attempt recovery"""
|
||||
if self.last_failure_time is None:
|
||||
return True
|
||||
|
||||
return time.time() - self.last_failure_time >= self.timeout
|
||||
|
||||
def _time_until_retry(self) -> float:
|
||||
"""Calculate seconds until next retry attempt"""
|
||||
if self.last_failure_time is None:
|
||||
return 0.0
|
||||
|
||||
elapsed = time.time() - self.last_failure_time
|
||||
return max(0.0, self.timeout - elapsed)
|
||||
|
||||
async def _on_success(self):
|
||||
"""Handle successful request"""
|
||||
async with self._lock:
|
||||
self.failure_count = 0
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
self.success_count += 1
|
||||
logger.debug(
|
||||
"Circuit breaker success in half-open state",
|
||||
service=self.service_name,
|
||||
success_count=self.success_count,
|
||||
success_threshold=self.success_threshold
|
||||
)
|
||||
|
||||
if self.success_count >= self.success_threshold:
|
||||
logger.info(
|
||||
"Circuit breaker closing - service recovered",
|
||||
service=self.service_name
|
||||
)
|
||||
self.state = CircuitState.CLOSED
|
||||
self.success_count = 0
|
||||
|
||||
async def _on_failure(self, exception: Exception):
|
||||
"""Handle failed request"""
|
||||
async with self._lock:
|
||||
self.failure_count += 1
|
||||
self.last_failure_time = time.time()
|
||||
|
||||
if self.state == CircuitState.HALF_OPEN:
|
||||
logger.warning(
|
||||
"Circuit breaker opening - recovery attempt failed",
|
||||
service=self.service_name,
|
||||
error=str(exception)
|
||||
)
|
||||
self.state = CircuitState.OPEN
|
||||
self.success_count = 0
|
||||
|
||||
elif self.state == CircuitState.CLOSED:
|
||||
logger.warning(
|
||||
"Circuit breaker failure recorded",
|
||||
service=self.service_name,
|
||||
failure_count=self.failure_count,
|
||||
threshold=self.failure_threshold,
|
||||
error=str(exception)
|
||||
)
|
||||
|
||||
if self.failure_count >= self.failure_threshold:
|
||||
logger.error(
|
||||
"Circuit breaker opening - failure threshold reached",
|
||||
service=self.service_name,
|
||||
failure_count=self.failure_count
|
||||
)
|
||||
self.state = CircuitState.OPEN
|
||||
|
||||
def get_state(self) -> str:
|
||||
"""Get current circuit breaker state"""
|
||||
return self.state.value
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if circuit is closed (normal operation)"""
|
||||
return self.state == CircuitState.CLOSED
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (failing fast)"""
|
||||
return self.state == CircuitState.OPEN
|
||||
|
||||
def is_half_open(self) -> bool:
|
||||
"""Check if circuit is half-open (testing recovery)"""
|
||||
return self.state == CircuitState.HALF_OPEN
|
||||
|
||||
async def reset(self):
|
||||
"""Manually reset circuit breaker to closed state"""
|
||||
async with self._lock:
|
||||
logger.info(
|
||||
"Circuit breaker manually reset",
|
||||
service=self.service_name,
|
||||
previous_state=self.state.value
|
||||
)
|
||||
self.state = CircuitState.CLOSED
|
||||
self.failure_count = 0
|
||||
self.success_count = 0
|
||||
self.last_failure_time = None
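
# Example (reviewer sketch, not part of the commit): minimal use of
# CircuitBreaker; flaky_call and the thresholds are illustrative stand-ins
# for a real downstream request.
import asyncio
from shared.clients.circuit_breaker import CircuitBreaker, CircuitBreakerOpenException

async def flaky_call() -> str:
    # Stand-in for an HTTP call to an unhealthy downstream service
    raise RuntimeError("service unavailable")

async def main():
    breaker = CircuitBreaker("demo-service", failure_threshold=2, timeout=1)
    for _ in range(2):  # two consecutive failures open the circuit
        try:
            await breaker.call(flaky_call)
        except RuntimeError:
            pass
    try:
        await breaker.call(flaky_call)  # now rejected without touching the service
    except CircuitBreakerOpenException as e:
        print(e)

asyncio.run(main())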

205 shared/clients/nominatim_client.py Normal file
@@ -0,0 +1,205 @@
"""
|
||||
Nominatim Client for geocoding and address search
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import httpx
|
||||
from typing import Optional, List, Dict, Any
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class NominatimClient:
|
||||
"""
|
||||
Client for Nominatim geocoding service.
|
||||
|
||||
Provides address search and geocoding capabilities for the bakery onboarding flow.
|
||||
"""
|
||||
|
||||
def __init__(self, config: BaseServiceSettings):
|
||||
self.config = config
|
||||
self.nominatim_url = getattr(
|
||||
config,
|
||||
"NOMINATIM_SERVICE_URL",
|
||||
"http://nominatim-service:8080"
|
||||
)
|
||||
self.timeout = 30
|
||||
|
||||
async def search_address(
|
||||
self,
|
||||
query: str,
|
||||
country_codes: str = "es",
|
||||
limit: int = 5,
|
||||
addressdetails: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for addresses matching a query.
|
||||
|
||||
Args:
|
||||
query: Address search query (e.g., "Calle Mayor 1, Madrid")
|
||||
country_codes: Limit search to country codes (default: "es" for Spain)
|
||||
limit: Maximum number of results (default: 5)
|
||||
addressdetails: Include detailed address breakdown (default: True)
|
||||
|
||||
Returns:
|
||||
List of geocoded results with lat, lon, and address details
|
||||
|
||||
Example:
|
||||
results = await nominatim.search_address("Calle Mayor 1, Madrid")
|
||||
if results:
|
||||
lat = results[0]["lat"]
|
||||
lon = results[0]["lon"]
|
||||
display_name = results[0]["display_name"]
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.nominatim_url}/search",
|
||||
params={
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"countrycodes": country_codes,
|
||||
"addressdetails": 1 if addressdetails else 0,
|
||||
"limit": limit
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
results = response.json()
|
||||
logger.info(
|
||||
"Address search completed",
|
||||
query=query,
|
||||
results_count=len(results)
|
||||
)
|
||||
return results
|
||||
else:
|
||||
logger.error(
|
||||
"Nominatim search failed",
|
||||
query=query,
|
||||
status_code=response.status_code,
|
||||
response=response.text
|
||||
)
|
||||
return []
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Nominatim search timeout", query=query)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error("Nominatim search error", query=query, error=str(e))
|
||||
return []
|
||||
|
||||
async def geocode_address(
|
||||
self,
|
||||
street: str,
|
||||
city: str,
|
||||
postal_code: Optional[str] = None,
|
||||
country: str = "Spain"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Geocode a structured address to coordinates.
|
||||
|
||||
Args:
|
||||
street: Street name and number
|
||||
city: City name
|
||||
postal_code: Optional postal code
|
||||
country: Country name (default: "Spain")
|
||||
|
||||
Returns:
|
||||
Dict with lat, lon, and display_name, or None if not found
|
||||
|
||||
Example:
|
||||
location = await nominatim.geocode_address(
|
||||
street="Calle Mayor 1",
|
||||
city="Madrid",
|
||||
postal_code="28013"
|
||||
)
|
||||
if location:
|
||||
lat, lon = location["lat"], location["lon"]
|
||||
"""
|
||||
# Build structured query
|
||||
query_parts = [street, city]
|
||||
if postal_code:
|
||||
query_parts.append(postal_code)
|
||||
query_parts.append(country)
|
||||
|
||||
query = ", ".join(query_parts)
|
||||
|
||||
results = await self.search_address(query, limit=1)
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
async def reverse_geocode(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Reverse geocode coordinates to an address.
|
||||
|
||||
Args:
|
||||
latitude: Latitude coordinate
|
||||
longitude: Longitude coordinate
|
||||
|
||||
Returns:
|
||||
Dict with address details, or None if not found
|
||||
|
||||
Example:
|
||||
address = await nominatim.reverse_geocode(40.4168, -3.7038)
|
||||
if address:
|
||||
city = address["address"]["city"]
|
||||
street = address["address"]["road"]
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.nominatim_url}/reverse",
|
||||
params={
|
||||
"lat": latitude,
|
||||
"lon": longitude,
|
||||
"format": "json",
|
||||
"addressdetails": 1
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
logger.info(
|
||||
"Reverse geocoding completed",
|
||||
lat=latitude,
|
||||
lon=longitude
|
||||
)
|
||||
return result
|
||||
else:
|
||||
logger.error(
|
||||
"Nominatim reverse geocoding failed",
|
||||
lat=latitude,
|
||||
lon=longitude,
|
||||
status_code=response.status_code
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Reverse geocoding error",
|
||||
lat=latitude,
|
||||
lon=longitude,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Nominatim service is healthy.
|
||||
|
||||
Returns:
|
||||
True if service is responding, False otherwise
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get(f"{self.nominatim_url}/status")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.warning("Nominatim health check failed", error=str(e))
|
||||
return False
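
# Example (reviewer sketch, not part of the commit): geocoding a bakery
# address during onboarding. The config argument is illustrative; the client
# only reads an optional NOMINATIM_SERVICE_URL attribute and otherwise falls
# back to the default service URL.
import asyncio
from shared.clients.nominatim_client import NominatimClient

async def main():
    nominatim = NominatimClient(config=object())  # no URL attr -> default URL
    location = await nominatim.geocode_address(street="Calle Mayor 1", city="Madrid")
    if location:
        print(location["lat"], location["lon"], location["display_name"])

asyncio.run(main())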

179 shared/monitoring/tracing.py Normal file
@@ -0,0 +1,179 @@
"""
|
||||
OpenTelemetry distributed tracing integration
|
||||
Provides end-to-end request tracking across all services
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Optional
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def setup_tracing(
|
||||
app,
|
||||
service_name: str,
|
||||
service_version: str = "1.0.0",
|
||||
jaeger_endpoint: str = "http://jaeger-collector.monitoring:4317"
|
||||
):
|
||||
"""
|
||||
Setup OpenTelemetry distributed tracing for a FastAPI service.
|
||||
|
||||
Automatically instruments:
|
||||
- FastAPI endpoints
|
||||
- HTTPX client requests (inter-service calls)
|
||||
- Redis operations
|
||||
- PostgreSQL/SQLAlchemy queries
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
service_name: Name of the service (e.g., "auth-service")
|
||||
service_version: Version of the service
|
||||
jaeger_endpoint: Jaeger collector gRPC endpoint
|
||||
|
||||
Example:
|
||||
from shared.monitoring.tracing import setup_tracing
|
||||
|
||||
app = FastAPI(title="Auth Service")
|
||||
setup_tracing(app, "auth-service")
|
||||
"""
|
||||
|
||||
try:
|
||||
# Create resource with service information
|
||||
resource = Resource(attributes={
|
||||
SERVICE_NAME: service_name,
|
||||
SERVICE_VERSION: service_version,
|
||||
"deployment.environment": "production"
|
||||
})
|
||||
|
||||
# Configure tracer provider
|
||||
tracer_provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(tracer_provider)
|
||||
|
||||
# Configure OTLP exporter to send to Jaeger
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=jaeger_endpoint,
|
||||
insecure=True # Use TLS in production
|
||||
)
|
||||
|
||||
# Add span processor with batching for performance
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
|
||||
# Auto-instrument FastAPI
|
||||
FastAPIInstrumentor.instrument_app(
|
||||
app,
|
||||
tracer_provider=tracer_provider,
|
||||
excluded_urls="health,metrics" # Don't trace health/metrics endpoints
|
||||
)
|
||||
|
||||
# Auto-instrument HTTPX (inter-service communication)
|
||||
HTTPXClientInstrumentor().instrument(tracer_provider=tracer_provider)
|
||||
|
||||
# Auto-instrument Redis
|
||||
try:
|
||||
RedisInstrumentor().instrument(tracer_provider=tracer_provider)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to instrument Redis: {e}")
|
||||
|
||||
# Auto-instrument PostgreSQL (psycopg2) - skip if not available
|
||||
# Most services use asyncpg instead of psycopg2
|
||||
# try:
|
||||
# Psycopg2Instrumentor().instrument(tracer_provider=tracer_provider)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to instrument Psycopg2: {e}")
|
||||
|
||||
# Auto-instrument SQLAlchemy
|
||||
try:
|
||||
SQLAlchemyInstrumentor().instrument(tracer_provider=tracer_provider)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to instrument SQLAlchemy: {e}")
|
||||
|
||||
logger.info(
|
||||
"Distributed tracing configured",
|
||||
service=service_name,
|
||||
jaeger_endpoint=jaeger_endpoint
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to setup tracing - continuing without it",
|
||||
service=service_name,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
|
||||
def get_current_trace_id() -> Optional[str]:
|
||||
"""
|
||||
Get the current trace ID for correlation with logs.
|
||||
|
||||
Returns:
|
||||
Trace ID as hex string, or None if no active trace
|
||||
"""
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().is_valid:
|
||||
return format(span.get_span_context().trace_id, '032x')
|
||||
return None
|
||||
|
||||
|
||||
def get_current_span_id() -> Optional[str]:
|
||||
"""
|
||||
Get the current span ID.
|
||||
|
||||
Returns:
|
||||
Span ID as hex string, or None if no active span
|
||||
"""
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().is_valid:
|
||||
return format(span.get_span_context().span_id, '016x')
|
||||
return None
|
||||
|
||||
|
||||
def add_trace_attributes(**attributes):
|
||||
"""
|
||||
Add custom attributes to the current span.
|
||||
|
||||
Example:
|
||||
add_trace_attributes(
|
||||
user_id="123",
|
||||
tenant_id="abc",
|
||||
operation="user_registration"
|
||||
)
|
||||
"""
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().is_valid:
|
||||
for key, value in attributes.items():
|
||||
span.set_attribute(key, str(value))
|
||||
|
||||
|
||||
def add_trace_event(name: str, **attributes):
|
||||
"""
|
||||
Add an event to the current span (for important operations).
|
||||
|
||||
Example:
|
||||
add_trace_event("user_authenticated", user_id="123", method="jwt")
|
||||
"""
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().is_valid:
|
||||
span.add_event(name, attributes)
|
||||
|
||||
|
||||
def record_exception(exception: Exception):
|
||||
"""
|
||||
Record an exception in the current span.
|
||||
|
||||
Args:
|
||||
exception: The exception to record
|
||||
"""
|
||||
span = trace.get_current_span()
|
||||
if span and span.get_span_context().is_valid:
|
||||
span.record_exception(exception)
|
||||
span.set_status(trace.Status(trace.StatusCode.ERROR, str(exception)))
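
# Example (reviewer sketch, not part of the commit): a structlog processor
# that stamps each log line with the active trace/span IDs so logs can be
# correlated with Jaeger traces; the processor name is illustrative.
import structlog
from shared.monitoring.tracing import get_current_trace_id, get_current_span_id

def add_trace_context(logger, method_name, event_dict):
    # Only attach IDs when a span is actually recording
    trace_id = get_current_trace_id()
    if trace_id:
        event_dict["trace_id"] = trace_id
        event_dict["span_id"] = get_current_span_id()
    return event_dict

structlog.configure(
    processors=[add_trace_context, structlog.processors.JSONRenderer()]
)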

33 shared/redis_utils/__init__.py Normal file
@@ -0,0 +1,33 @@
"""
|
||||
Redis utilities for Bakery-IA platform
|
||||
Provides Redis connection management and rate limiting
|
||||
"""
|
||||
|
||||
from shared.redis_utils.client import (
|
||||
RedisConnectionManager,
|
||||
get_redis_manager,
|
||||
initialize_redis,
|
||||
get_redis_client,
|
||||
close_redis,
|
||||
redis_context,
|
||||
set_with_ttl,
|
||||
get_value,
|
||||
increment_counter,
|
||||
get_keys_pattern
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Connection management
|
||||
"RedisConnectionManager",
|
||||
"get_redis_manager",
|
||||
"initialize_redis",
|
||||
"get_redis_client",
|
||||
"close_redis",
|
||||
"redis_context",
|
||||
|
||||
# Convenience functions
|
||||
"set_with_ttl",
|
||||
"get_value",
|
||||
"increment_counter",
|
||||
"get_keys_pattern",
|
||||
]

329 shared/redis_utils/client.py Normal file
@@ -0,0 +1,329 @@
"""
|
||||
Redis client initialization and connection management
|
||||
Provides standardized Redis connection for all services
|
||||
"""
|
||||
|
||||
import redis.asyncio as redis
|
||||
from typing import Optional
|
||||
import structlog
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RedisConnectionManager:
|
||||
"""
|
||||
Manages Redis connections with connection pooling and error handling
|
||||
Thread-safe singleton pattern for sharing connections across service
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._client: Optional[redis.Redis] = None
|
||||
self._pool: Optional[redis.ConnectionPool] = None
|
||||
self.logger = logger
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
redis_url: str,
|
||||
db: int = 0,
|
||||
max_connections: int = 50,
|
||||
decode_responses: bool = True,
|
||||
retry_on_timeout: bool = True,
|
||||
socket_keepalive: bool = True,
|
||||
health_check_interval: int = 30
|
||||
):
|
||||
"""
|
||||
Initialize Redis connection with pool
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL (redis://[:password]@host:port)
|
||||
db: Database number (0-15)
|
||||
max_connections: Maximum connections in pool
|
||||
decode_responses: Automatically decode responses to strings
|
||||
retry_on_timeout: Retry on timeout errors
|
||||
socket_keepalive: Enable TCP keepalive
|
||||
health_check_interval: Health check interval in seconds
|
||||
"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self._pool = redis.ConnectionPool.from_url(
|
||||
redis_url,
|
||||
db=db,
|
||||
max_connections=max_connections,
|
||||
decode_responses=decode_responses,
|
||||
retry_on_timeout=retry_on_timeout,
|
||||
socket_keepalive=socket_keepalive,
|
||||
health_check_interval=health_check_interval
|
||||
)
|
||||
|
||||
# Create Redis client with pool
|
||||
self._client = redis.Redis(connection_pool=self._pool)
|
||||
|
||||
# Test connection
|
||||
await self._client.ping()
|
||||
|
||||
self.logger.info(
|
||||
"redis_initialized",
|
||||
redis_url=redis_url.split("@")[-1], # Log only host:port, not password
|
||||
db=db,
|
||||
max_connections=max_connections
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"redis_initialization_failed",
|
||||
error=str(e),
|
||||
redis_url=redis_url.split("@")[-1]
|
||||
)
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""Close Redis connection and pool"""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self.logger.info("redis_client_closed")
|
||||
|
||||
if self._pool:
|
||||
await self._pool.disconnect()
|
||||
self.logger.info("redis_pool_closed")
|
||||
|
||||
def get_client(self) -> redis.Redis:
|
||||
"""
|
||||
Get Redis client instance
|
||||
|
||||
Returns:
|
||||
Redis client
|
||||
|
||||
Raises:
|
||||
RuntimeError: If client not initialized
|
||||
"""
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client not initialized. Call initialize() first.")
|
||||
return self._client
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check Redis connection health
|
||||
|
||||
Returns:
|
||||
bool: True if healthy, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self._client is None:
|
||||
return False
|
||||
|
||||
await self._client.ping()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("redis_health_check_failed", error=str(e))
|
||||
return False
|
||||
|
||||
async def get_info(self) -> dict:
|
||||
"""
|
||||
Get Redis server information
|
||||
|
||||
Returns:
|
||||
dict: Redis INFO command output
|
||||
"""
|
||||
try:
|
||||
if self._client is None:
|
||||
return {}
|
||||
|
||||
return await self._client.info()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("redis_info_failed", error=str(e))
|
||||
return {}
|
||||
|
||||
async def flush_db(self):
|
||||
"""
|
||||
Flush current database (USE WITH CAUTION)
|
||||
Only for development/testing
|
||||
"""
|
||||
try:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
|
||||
await self._client.flushdb()
|
||||
self.logger.warning("redis_database_flushed")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("redis_flush_failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
# Global connection manager instance
|
||||
_redis_manager: Optional[RedisConnectionManager] = None
|
||||
|
||||
|
||||
async def get_redis_manager() -> RedisConnectionManager:
|
||||
"""
|
||||
Get or create global Redis manager instance
|
||||
|
||||
Returns:
|
||||
RedisConnectionManager instance
|
||||
"""
|
||||
global _redis_manager
|
||||
if _redis_manager is None:
|
||||
_redis_manager = RedisConnectionManager()
|
||||
return _redis_manager
|
||||
|
||||
|
||||
async def initialize_redis(
|
||||
redis_url: str,
|
||||
db: int = 0,
|
||||
max_connections: int = 50,
|
||||
**kwargs
|
||||
) -> redis.Redis:
|
||||
"""
|
||||
Initialize Redis and return client
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
db: Database number
|
||||
max_connections: Maximum connections in pool
|
||||
**kwargs: Additional connection parameters
|
||||
|
||||
Returns:
|
||||
Redis client instance
|
||||
"""
|
||||
manager = await get_redis_manager()
|
||||
await manager.initialize(
|
||||
redis_url=redis_url,
|
||||
db=db,
|
||||
max_connections=max_connections,
|
||||
**kwargs
|
||||
)
|
||||
return manager.get_client()
|
||||
|
||||
|
||||
async def get_redis_client() -> redis.Redis:
|
||||
"""
|
||||
Get initialized Redis client
|
||||
|
||||
Returns:
|
||||
Redis client instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Redis not initialized
|
||||
"""
|
||||
manager = await get_redis_manager()
|
||||
return manager.get_client()
|
||||
|
||||
|
||||
async def close_redis():
|
||||
"""Close Redis connections"""
|
||||
global _redis_manager
|
||||
if _redis_manager:
|
||||
await _redis_manager.close()
|
||||
_redis_manager = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def redis_context(redis_url: str, db: int = 0):
|
||||
"""
|
||||
Context manager for Redis connections
|
||||
|
||||
Usage:
|
||||
async with redis_context(settings.REDIS_URL) as client:
|
||||
await client.set("key", "value")
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
db: Database number
|
||||
|
||||
Yields:
|
||||
Redis client
|
||||
"""
|
||||
client = None
|
||||
try:
|
||||
client = await initialize_redis(redis_url, db=db)
|
||||
yield client
|
||||
finally:
|
||||
if client:
|
||||
await close_redis()
|
||||
|
||||
|
||||
# Convenience functions for common operations
|
||||
async def set_with_ttl(key: str, value: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set key with TTL
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
value: Value to set
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
client = await get_redis_client()
|
||||
await client.setex(key, ttl, value)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("redis_set_failed", key=key, error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
async def get_value(key: str) -> Optional[str]:
|
||||
"""
|
||||
Get value by key
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
Value or None if not found
|
||||
"""
|
||||
try:
|
||||
client = await get_redis_client()
|
||||
return await client.get(key)
|
||||
except Exception as e:
|
||||
logger.error("redis_get_failed", key=key, error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
async def increment_counter(key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
|
||||
"""
|
||||
Increment counter with optional TTL
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
amount: Amount to increment
|
||||
ttl: Time to live in seconds (sets on first increment)
|
||||
|
||||
Returns:
|
||||
New counter value
|
||||
"""
|
||||
try:
|
||||
client = await get_redis_client()
|
||||
new_value = await client.incrby(key, amount)
|
||||
|
||||
# Set TTL if specified and key is new (value == amount)
|
||||
if ttl and new_value == amount:
|
||||
await client.expire(key, ttl)
|
||||
|
||||
return new_value
|
||||
except Exception as e:
|
||||
logger.error("redis_increment_failed", key=key, error=str(e))
|
||||
return 0
|
||||
|
||||
|
||||
async def get_keys_pattern(pattern: str) -> list:
|
||||
"""
|
||||
Get keys matching pattern
|
||||
|
||||
Args:
|
||||
pattern: Redis key pattern (e.g., "quota:*")
|
||||
|
||||
Returns:
|
||||
List of matching keys
|
||||
"""
|
||||
try:
|
||||
client = await get_redis_client()
|
||||
return await client.keys(pattern)
|
||||
except Exception as e:
|
||||
logger.error("redis_keys_failed", pattern=pattern, error=str(e))
|
||||
return []
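
# Example (reviewer sketch, not part of the commit): using the helpers above;
# the redis:// URL and key names are placeholders.
import asyncio
from shared.redis_utils import redis_context, increment_counter

async def main():
    async with redis_context("redis://localhost:6379") as client:
        await client.set("greeting", "hola")
        print(await client.get("greeting"))
        # First increment creates the key and applies the 60-second TTL
        print(await increment_counter("page_views", ttl=60))

asyncio.run(main())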

9 shared/requirements-tracing.txt Normal file
@@ -0,0 +1,9 @@
# OpenTelemetry dependencies for distributed tracing
opentelemetry-api==1.21.0
opentelemetry-sdk==1.21.0
opentelemetry-instrumentation-fastapi==0.42b0
opentelemetry-instrumentation-httpx==0.42b0
opentelemetry-instrumentation-redis==0.42b0
# opentelemetry-instrumentation-psycopg2==0.42b0  # Commented out - not all services use psycopg2
opentelemetry-instrumentation-sqlalchemy==0.42b0
opentelemetry-exporter-otlp-proto-grpc==1.21.0

31 shared/security/__init__.py Normal file
@@ -0,0 +1,31 @@
"""
|
||||
Security utilities for RBAC, audit logging, and rate limiting
|
||||
"""
|
||||
|
||||
from shared.security.audit_logger import (
|
||||
AuditLogger,
|
||||
AuditSeverity,
|
||||
AuditAction,
|
||||
create_audit_logger,
|
||||
create_audit_log_model
|
||||
)
|
||||
|
||||
from shared.security.rate_limiter import (
|
||||
RateLimiter,
|
||||
QuotaType,
|
||||
create_rate_limiter
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Audit logging
|
||||
"AuditLogger",
|
||||
"AuditSeverity",
|
||||
"AuditAction",
|
||||
"create_audit_logger",
|
||||
"create_audit_log_model",
|
||||
|
||||
# Rate limiting
|
||||
"RateLimiter",
|
||||
"QuotaType",
|
||||
"create_rate_limiter",
|
||||
]

317 shared/security/audit_logger.py Normal file
@@ -0,0 +1,317 @@
"""
|
||||
Audit logging system for tracking critical operations across all services
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import structlog
|
||||
from sqlalchemy import Column, String, DateTime, Text, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuditSeverity(str, Enum):
|
||||
"""Severity levels for audit events"""
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class AuditAction(str, Enum):
|
||||
"""Common audit action types"""
|
||||
CREATE = "create"
|
||||
READ = "read"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
APPROVE = "approve"
|
||||
REJECT = "reject"
|
||||
CANCEL = "cancel"
|
||||
EXPORT = "export"
|
||||
IMPORT = "import"
|
||||
INVITE = "invite"
|
||||
REMOVE = "remove"
|
||||
UPGRADE = "upgrade"
|
||||
DOWNGRADE = "downgrade"
|
||||
DEACTIVATE = "deactivate"
|
||||
ACTIVATE = "activate"
|
||||
|
||||
|
||||
def create_audit_log_model(Base):
|
||||
"""
|
||||
Factory function to create AuditLog model for any service
|
||||
Each service has its own audit_logs table in their database
|
||||
|
||||
Usage in service models/__init__.py:
|
||||
from shared.database.base import Base
|
||||
from shared.security import create_audit_log_model
|
||||
|
||||
AuditLog = create_audit_log_model(Base)
|
||||
|
||||
Args:
|
||||
Base: SQLAlchemy declarative base for the service
|
||||
|
||||
Returns:
|
||||
AuditLog model class bound to the service's Base
|
||||
"""
|
||||
|
||||
class AuditLog(Base):
|
||||
"""
|
||||
Audit log model for tracking critical operations
|
||||
Each service has its own audit_logs table for data locality
|
||||
"""
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
# Primary identification
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
# Tenant and user context
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Action details
|
||||
action = Column(String(100), nullable=False, index=True) # create, update, delete, etc.
|
||||
resource_type = Column(String(100), nullable=False, index=True) # supplier, recipe, order, etc.
|
||||
resource_id = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Severity and categorization
|
||||
severity = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="medium",
|
||||
index=True
|
||||
) # low, medium, high, critical
|
||||
|
||||
# Service identification
|
||||
service_name = Column(String(100), nullable=False, index=True)
|
||||
|
||||
# Details
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Audit trail data
|
||||
changes = Column(JSON, nullable=True) # Before/after values for updates
|
||||
audit_metadata = Column(JSON, nullable=True) # Additional context
|
||||
|
||||
# Request context
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 or IPv6
|
||||
user_agent = Column(Text, nullable=True)
|
||||
endpoint = Column(String(255), nullable=True)
|
||||
method = Column(String(10), nullable=True) # GET, POST, PUT, DELETE
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
index=True
|
||||
)
|
||||
|
||||
# Composite indexes for common query patterns
|
||||
__table_args__ = (
|
||||
Index('idx_audit_tenant_created', 'tenant_id', 'created_at'),
|
||||
Index('idx_audit_user_created', 'user_id', 'created_at'),
|
||||
Index('idx_audit_resource_type_action', 'resource_type', 'action'),
|
||||
Index('idx_audit_severity_created', 'severity', 'created_at'),
|
||||
Index('idx_audit_service_created', 'service_name', 'created_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<AuditLog(id={self.id}, tenant={self.tenant_id}, "
|
||||
f"action={self.action}, resource={self.resource_type}, "
|
||||
f"severity={self.severity})>"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert audit log to dictionary"""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"tenant_id": str(self.tenant_id),
|
||||
"user_id": str(self.user_id),
|
||||
"action": self.action,
|
||||
"resource_type": self.resource_type,
|
||||
"resource_id": self.resource_id,
|
||||
"severity": self.severity,
|
||||
"service_name": self.service_name,
|
||||
"description": self.description,
|
||||
"changes": self.changes,
|
||||
"metadata": self.audit_metadata,
|
||||
"ip_address": self.ip_address,
|
||||
"user_agent": self.user_agent,
|
||||
"endpoint": self.endpoint,
|
||||
"method": self.method,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
return AuditLog
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""Service for logging audit events"""
|
||||
|
||||
def __init__(self, service_name: str):
|
||||
self.service_name = service_name
|
||||
self.logger = logger.bind(service=service_name)
|
||||
|
||||
async def log_event(
|
||||
self,
|
||||
db_session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
action: str,
|
||||
resource_type: str,
|
||||
resource_id: Optional[str] = None,
|
||||
severity: str = "medium",
|
||||
description: Optional[str] = None,
|
||||
changes: Optional[Dict[str, Any]] = None,
|
||||
audit_metadata: Optional[Dict[str, Any]] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
method: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Log an audit event
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID
|
||||
user_id: User ID who performed the action
|
||||
action: Action performed (create, update, delete, etc.)
|
||||
resource_type: Type of resource (user, sale, recipe, etc.)
|
||||
resource_id: ID of the resource affected
|
||||
severity: Severity level (low, medium, high, critical)
|
||||
description: Human-readable description
|
||||
changes: Dictionary of before/after values for updates
|
||||
audit_metadata: Additional context
|
||||
endpoint: API endpoint
|
||||
method: HTTP method
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
"""
|
||||
try:
|
||||
audit_log = AuditLog(
|
||||
tenant_id=uuid.UUID(tenant_id) if isinstance(tenant_id, str) else tenant_id,
|
||||
user_id=uuid.UUID(user_id) if isinstance(user_id, str) else user_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
severity=severity,
|
||||
service_name=self.service_name,
|
||||
description=description,
|
||||
changes=changes,
|
||||
audit_metadata=audit_metadata,
|
||||
endpoint=endpoint,
|
||||
method=method,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
db_session.add(audit_log)
|
||||
await db_session.commit()
|
||||
|
||||
self.logger.info(
|
||||
"audit_event_logged",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=str(user_id),
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
severity=severity,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"audit_log_failed",
|
||||
error=str(e),
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=str(user_id),
|
||||
action=action,
|
||||
)
|
||||
# Don't raise - audit logging should not block operations
|
||||
|
||||
async def log_deletion(
|
||||
self,
|
||||
db_session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
resource_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Convenience method for logging deletions"""
|
||||
return await self.log_event(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
action=AuditAction.DELETE.value,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
severity=AuditSeverity.HIGH.value,
|
||||
description=f"Deleted {resource_type} {resource_id}",
|
||||
audit_metadata={"deleted_data": resource_data} if resource_data else None,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def log_role_change(
|
||||
self,
|
||||
db_session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
target_user_id: str,
|
||||
old_role: str,
|
||||
new_role: str,
|
||||
**kwargs
|
||||
):
|
||||
"""Convenience method for logging role changes"""
|
||||
return await self.log_event(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
action=AuditAction.UPDATE.value,
|
||||
resource_type="user_role",
|
||||
resource_id=target_user_id,
|
||||
severity=AuditSeverity.HIGH.value,
|
||||
description=f"Changed user role from {old_role} to {new_role}",
|
||||
changes={
|
||||
"before": {"role": old_role},
|
||||
"after": {"role": new_role}
|
||||
},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def log_subscription_change(
|
||||
self,
|
||||
db_session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
action: str,
|
||||
old_plan: Optional[str] = None,
|
||||
new_plan: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Convenience method for logging subscription changes"""
|
||||
return await self.log_event(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
resource_type="subscription",
|
||||
resource_id=tenant_id,
|
||||
severity=AuditSeverity.CRITICAL.value,
|
||||
description=f"Subscription {action}: {old_plan} -> {new_plan}" if old_plan else f"Subscription {action}: {new_plan}",
|
||||
changes={
|
||||
"before": {"plan": old_plan} if old_plan else None,
|
||||
"after": {"plan": new_plan} if new_plan else None
|
||||
},
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def create_audit_logger(service_name: str) -> AuditLogger:
|
||||
"""Factory function to create audit logger for a service"""
|
||||
return AuditLogger(service_name)
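
# Example (reviewer sketch, not part of the commit): wiring the per-service
# AuditLog model into the logger and recording a deletion; the session and
# IDs are illustrative.
from shared.database.base import Base
from shared.security import create_audit_log_model, create_audit_logger

AuditLog = create_audit_log_model(Base)
audit = create_audit_logger("inventory-service", AuditLog)

async def delete_supplier(db_session, tenant_id: str, user_id: str, supplier_id: str):
    ...  # perform the actual deletion first
    await audit.log_deletion(
        db_session=db_session,
        tenant_id=tenant_id,
        user_id=user_id,
        resource_type="supplier",
        resource_id=supplier_id,
    )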

388 shared/security/rate_limiter.py Normal file
@@ -0,0 +1,388 @@
"""
|
||||
Rate limiting and quota management system for subscription-based features
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import structlog
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class QuotaType(str, Enum):
|
||||
"""Types of quotas"""
|
||||
FORECAST_GENERATION = "forecast_generation"
|
||||
TRAINING_JOBS = "training_jobs"
|
||||
BULK_IMPORTS = "bulk_imports"
|
||||
POS_SYNC = "pos_sync"
|
||||
API_CALLS = "api_calls"
|
||||
DEMO_SESSIONS = "demo_sessions"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
Redis-based rate limiter for subscription tier quotas
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
"""
|
||||
Initialize rate limiter
|
||||
|
||||
Args:
|
||||
redis_client: Redis client for storing quota counters
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.logger = logger
|
||||
|
||||
def _get_quota_key(self, tenant_id: str, quota_type: str, period: str = "daily") -> str:
|
||||
"""Generate Redis key for quota tracking"""
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
return f"quota:{period}:{quota_type}:{tenant_id}:{date_str}"
|
||||
|
||||
def _get_dataset_size_key(self, tenant_id: str) -> str:
|
||||
"""Generate Redis key for dataset size tracking"""
|
||||
return f"dataset_size:{tenant_id}"
|
||||
|
||||
async def check_and_increment_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_type: str,
|
||||
limit: Optional[int],
|
||||
period: int = 86400 # 24 hours in seconds
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if quota allows action and increment counter
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to check
|
||||
limit: Maximum allowed count (None = unlimited)
|
||||
period: Time period in seconds (default: 24 hours)
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- allowed: bool
|
||||
- current: int (current count)
|
||||
- limit: Optional[int]
|
||||
- reset_at: datetime (when quota resets)
|
||||
|
||||
Raises:
|
||||
HTTPException: If quota is exceeded
|
||||
"""
|
||||
if limit is None:
|
||||
# Unlimited quota
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": 0,
|
||||
"limit": None,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
|
||||
try:
|
||||
# Get current count
|
||||
current = await self.redis.get(key)
|
||||
current_count = int(current) if current else 0
|
||||
|
||||
# Check if limit exceeded
|
||||
if current_count >= limit:
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
|
||||
|
||||
self.logger.warning(
|
||||
"quota_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type,
|
||||
current=current_count,
|
||||
limit=limit,
|
||||
reset_at=reset_at.isoformat()
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"error": "quota_exceeded",
|
||||
"message": f"Daily quota exceeded for {quota_type}",
|
||||
"current": current_count,
|
||||
"limit": limit,
|
||||
"reset_at": reset_at.isoformat(),
|
||||
"quota_type": quota_type
|
||||
}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, period)
|
||||
await pipe.execute()
|
||||
|
||||
new_count = current_count + 1
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl if ttl > 0 else period)
|
||||
|
||||
self.logger.info(
|
||||
"quota_incremented",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type,
|
||||
current=new_count,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": new_count,
|
||||
"limit": limit,
|
||||
"reset_at": reset_at
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"quota_check_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
# Fail open - allow the operation
|
||||
return {
|
||||
"allowed": True,
|
||||
"current": 0,
|
||||
"limit": limit,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
async def get_current_usage(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current quota usage without incrementing
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to check
|
||||
|
||||
Returns:
|
||||
Dict with current usage information
|
||||
"""
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
|
||||
try:
|
||||
current = await self.redis.get(key)
|
||||
current_count = int(current) if current else 0
|
||||
ttl = await self.redis.ttl(key)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
|
||||
return {
|
||||
"current": current_count,
|
||||
"reset_at": reset_at
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"usage_check_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
return {
|
||||
"current": 0,
|
||||
"reset_at": None
|
||||
}
|
||||
|
||||
async def reset_quota(self, tenant_id: str, quota_type: str):
|
||||
"""
|
||||
Reset quota for a tenant (admin function)
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
quota_type: Type of quota to reset
|
||||
"""
|
||||
key = self._get_quota_key(tenant_id, quota_type)
|
||||
try:
|
||||
await self.redis.delete(key)
|
||||
self.logger.info(
|
||||
"quota_reset",
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"quota_reset_failed",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
quota_type=quota_type
|
||||
)
|
||||
|
||||
async def validate_dataset_size(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_rows: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate dataset size against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
dataset_rows: Number of rows in dataset
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If dataset size exceeds limit
|
||||
"""
|
||||
# Dataset size limits per tier
|
||||
dataset_limits = {
|
||||
'starter': 1000,
|
||||
'professional': 10000,
|
||||
'enterprise': None # Unlimited
|
||||
}
|
||||
|
||||
limit = dataset_limits.get(subscription_tier.lower())
|
||||
|
||||
if limit is not None and dataset_rows > limit:
|
||||
self.logger.warning(
|
||||
"dataset_size_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
dataset_rows=dataset_rows,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "dataset_size_limit_exceeded",
|
||||
"message": f"Dataset size limited to {limit:,} rows for {subscription_tier} tier",
|
||||
"current_size": dataset_rows,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"dataset_size_validated",
|
||||
tenant_id=tenant_id,
|
||||
dataset_rows=dataset_rows,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
async def validate_forecast_horizon(
|
||||
self,
|
||||
tenant_id: str,
|
||||
horizon_days: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate forecast horizon against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
horizon_days: Number of days to forecast
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If horizon exceeds limit
|
||||
"""
|
||||
# Forecast horizon limits per tier
|
||||
horizon_limits = {
|
||||
'starter': 7,
|
||||
'professional': 90,
|
||||
'enterprise': 365 # Practically unlimited
|
||||
}
|
||||
|
||||
limit = horizon_limits.get(subscription_tier.lower(), 7)
|
||||
|
||||
if horizon_days > limit:
|
||||
self.logger.warning(
|
||||
"forecast_horizon_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
horizon_days=horizon_days,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "forecast_horizon_limit_exceeded",
|
||||
"message": f"Forecast horizon limited to {limit} days for {subscription_tier} tier",
|
||||
"requested_horizon": horizon_days,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"forecast_horizon_validated",
|
||||
tenant_id=tenant_id,
|
||||
horizon_days=horizon_days,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
async def validate_historical_data_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
days_back: int,
|
||||
subscription_tier: str
|
||||
):
|
||||
"""
|
||||
Validate historical data access against subscription tier limits
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
days_back: Number of days of historical data requested
|
||||
subscription_tier: User's subscription tier
|
||||
|
||||
Raises:
|
||||
HTTPException: If historical data access exceeds limit
|
||||
"""
|
||||
# Historical data limits per tier
|
||||
history_limits = {
|
||||
'starter': 7,
|
||||
'professional': 90,
|
||||
'enterprise': None # Unlimited
|
||||
}
|
||||
|
||||
limit = history_limits.get(subscription_tier.lower(), 7)
|
||||
|
||||
if limit is not None and days_back > limit:
|
||||
self.logger.warning(
|
||||
"historical_data_limit_exceeded",
|
||||
tenant_id=tenant_id,
|
||||
days_back=days_back,
|
||||
limit=limit,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "historical_data_limit_exceeded",
|
||||
"message": f"Historical data limited to {limit} days for {subscription_tier} tier",
|
||||
"requested_days": days_back,
|
||||
"limit": limit,
|
||||
"tier": subscription_tier,
|
||||
"upgrade_url": "/app/settings/profile"
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"historical_data_access_validated",
|
||||
tenant_id=tenant_id,
|
||||
days_back=days_back,
|
||||
tier=subscription_tier
|
||||
)
|
||||
|
||||
|
||||
def create_rate_limiter(redis_client) -> RateLimiter:
|
||||
"""Factory function to create rate limiter"""
|
||||
return RateLimiter(redis_client)
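
# Example (reviewer sketch, not part of the commit): enforcing a daily quota
# in an endpoint. The limit value is illustrative and would normally come from
# the tenant's tier; check_and_increment_quota raises HTTP 429 when exhausted.
from shared.redis_utils import get_redis_client
from shared.security import QuotaType, create_rate_limiter

async def generate_forecast(tenant_id: str):
    limiter = create_rate_limiter(await get_redis_client())
    await limiter.check_and_increment_quota(
        tenant_id=tenant_id,
        quota_type=QuotaType.FORECAST_GENERATION.value,
        limit=10,  # illustrative; see QuotaLimits in shared/subscription/plans.py
    )
    ...  # proceed with forecast generation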

@@ -23,6 +23,7 @@ from fastapi.routing import APIRouter
from shared.monitoring import setup_logging
from shared.monitoring.metrics import setup_metrics_early
from shared.monitoring.health_checks import setup_fastapi_health_checks
from shared.monitoring.tracing import setup_tracing
from shared.database.base import DatabaseManager

if TYPE_CHECKING:
@@ -51,6 +52,7 @@ class BaseFastAPIService:
        enable_cors: bool = True,
        enable_exception_handlers: bool = True,
        enable_messaging: bool = False,
        enable_tracing: bool = True,
        custom_metrics: Optional[Dict[str, Dict[str, Any]]] = None,
        alert_service_class: Optional[type] = None
    ):
@@ -69,6 +71,7 @@ class BaseFastAPIService:
        self.enable_cors = enable_cors
        self.enable_exception_handlers = enable_exception_handlers
        self.enable_messaging = enable_messaging
        self.enable_tracing = enable_tracing
        self.custom_metrics = custom_metrics or {}
        self.alert_service_class = alert_service_class

@@ -106,6 +109,18 @@ class BaseFastAPIService:
        if self.enable_metrics:
            self.metrics_collector = setup_metrics_early(self.app, self.service_name)

        # Setup distributed tracing
        if self.enable_tracing:
            try:
                jaeger_endpoint = os.getenv(
                    "JAEGER_COLLECTOR_ENDPOINT",
                    "http://jaeger-collector.monitoring:4317"
                )
                setup_tracing(self.app, self.service_name, self.version, jaeger_endpoint)
                self.logger.info(f"Distributed tracing enabled for {self.service_name}")
            except Exception as e:
                self.logger.warning(f"Failed to setup tracing, continuing without it: {e}")

        # Setup lifespan
        self.app.router.lifespan_context = self._create_lifespan()
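
# Example (reviewer sketch, not part of the commit): a service opting out of
# tracing via the new flag; other constructor arguments are illustrative and
# not shown in this diff.
service = BaseFastAPIService(
    service_name="auth-service",
    enable_tracing=False,  # defaults to True, which configures the OTLP exporter
)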

486  shared/subscription/plans.py  Normal file
@@ -0,0 +1,486 @@
"""
Centralized Subscription Plan Configuration
Owner: Tenant Service
Single source of truth for all subscription tiers, quotas, features, and limits
"""

from typing import Optional, Dict, Any, List
from enum import Enum
from decimal import Decimal


class SubscriptionTier(str, Enum):
    """Subscription tier enumeration"""
    STARTER = "starter"
    PROFESSIONAL = "professional"
    ENTERPRISE = "enterprise"


class BillingCycle(str, Enum):
    """Billing cycle options"""
    MONTHLY = "monthly"
    YEARLY = "yearly"


# ============================================================================
# PRICING CONFIGURATION
# ============================================================================

class PlanPricing:
    """Pricing for each subscription tier"""

    MONTHLY_PRICES = {
        SubscriptionTier.STARTER: Decimal("49.00"),
        SubscriptionTier.PROFESSIONAL: Decimal("149.00"),
        SubscriptionTier.ENTERPRISE: Decimal("499.00"),  # Base price, custom quotes available
    }

    YEARLY_PRICES = {
        SubscriptionTier.STARTER: Decimal("490.00"),  # ~17% discount (2 months free)
        SubscriptionTier.PROFESSIONAL: Decimal("1490.00"),  # ~17% discount
        SubscriptionTier.ENTERPRISE: Decimal("4990.00"),  # Base price, custom quotes available
    }

    @staticmethod
    def get_price(tier: str, billing_cycle: str = "monthly") -> Decimal:
        """Get price for tier and billing cycle"""
        tier_enum = SubscriptionTier(tier.lower())
        if billing_cycle == "yearly":
            return PlanPricing.YEARLY_PRICES[tier_enum]
        return PlanPricing.MONTHLY_PRICES[tier_enum]
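
# --- Illustrative aside (not part of the file): a quick sanity check of the
# yearly discount, assuming the module is importable as shared.subscription.plans.
from decimal import Decimal
from shared.subscription.plans import PlanPricing

monthly = PlanPricing.get_price("professional")           # Decimal("149.00")
yearly = PlanPricing.get_price("professional", "yearly")  # Decimal("1490.00")
assert yearly == monthly * Decimal(10)  # yearly == ten monthly payments (2 months free)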


# ============================================================================
# QUOTA LIMITS CONFIGURATION
# ============================================================================

class QuotaLimits:
    """
    Resource quotas and limits for each subscription tier
    None = Unlimited
    """

    # ===== Team & Organization Limits =====
    MAX_USERS = {
        SubscriptionTier.STARTER: 5,
        SubscriptionTier.PROFESSIONAL: 20,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    MAX_LOCATIONS = {
        SubscriptionTier.STARTER: 1,
        SubscriptionTier.PROFESSIONAL: 3,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== Product & Inventory Limits =====
    MAX_PRODUCTS = {
        SubscriptionTier.STARTER: 50,
        SubscriptionTier.PROFESSIONAL: 500,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    MAX_RECIPES = {
        SubscriptionTier.STARTER: 25,
        SubscriptionTier.PROFESSIONAL: 250,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    MAX_SUPPLIERS = {
        SubscriptionTier.STARTER: 10,
        SubscriptionTier.PROFESSIONAL: 100,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== ML & Analytics Quotas (Daily Limits) =====
    TRAINING_JOBS_PER_DAY = {
        SubscriptionTier.STARTER: 1,
        SubscriptionTier.PROFESSIONAL: 5,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    FORECAST_GENERATION_PER_DAY = {
        SubscriptionTier.STARTER: 10,
        SubscriptionTier.PROFESSIONAL: 100,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== Data Limits =====
    DATASET_SIZE_ROWS = {
        SubscriptionTier.STARTER: 1000,
        SubscriptionTier.PROFESSIONAL: 10000,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    FORECAST_HORIZON_DAYS = {
        SubscriptionTier.STARTER: 7,
        SubscriptionTier.PROFESSIONAL: 90,
        SubscriptionTier.ENTERPRISE: 365,
    }

    HISTORICAL_DATA_ACCESS_DAYS = {
        SubscriptionTier.STARTER: 30,       # 1 month
        SubscriptionTier.PROFESSIONAL: 365, # 1 year
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== Import/Export Limits =====
    BULK_IMPORT_ROWS = {
        SubscriptionTier.STARTER: 100,
        SubscriptionTier.PROFESSIONAL: 1000,
        SubscriptionTier.ENTERPRISE: 10000,
    }

    BULK_EXPORT_ROWS = {
        SubscriptionTier.STARTER: 1000,
        SubscriptionTier.PROFESSIONAL: 10000,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== Integration Limits =====
    POS_SYNC_INTERVAL_MINUTES = {
        SubscriptionTier.STARTER: 60,      # Hourly
        SubscriptionTier.PROFESSIONAL: 15, # Every 15 minutes
        SubscriptionTier.ENTERPRISE: 5,    # Every 5 minutes (near real-time)
    }

    API_CALLS_PER_HOUR = {
        SubscriptionTier.STARTER: 100,
        SubscriptionTier.PROFESSIONAL: 1000,
        SubscriptionTier.ENTERPRISE: 10000,
    }

    WEBHOOK_ENDPOINTS = {
        SubscriptionTier.STARTER: 2,
        SubscriptionTier.PROFESSIONAL: 10,
        SubscriptionTier.ENTERPRISE: None,  # Unlimited
    }

    # ===== Storage Limits =====
    FILE_STORAGE_GB = {
        SubscriptionTier.STARTER: 1,
        SubscriptionTier.PROFESSIONAL: 10,
        SubscriptionTier.ENTERPRISE: 100,
    }

    REPORT_RETENTION_DAYS = {
        SubscriptionTier.STARTER: 30,
        SubscriptionTier.PROFESSIONAL: 180,
        SubscriptionTier.ENTERPRISE: 365,
    }

    @staticmethod
    def get_limit(quota_type: str, tier: str) -> Optional[int]:
        """
        Get quota limit for a specific type and tier

        Args:
            quota_type: Quota type (e.g., 'MAX_USERS')
            tier: Subscription tier

        Returns:
            Optional[int]: Limit value or None for unlimited
        """
        tier_enum = SubscriptionTier(tier.lower())

        quota_map = {
            'MAX_USERS': QuotaLimits.MAX_USERS,
            'MAX_LOCATIONS': QuotaLimits.MAX_LOCATIONS,
            'MAX_PRODUCTS': QuotaLimits.MAX_PRODUCTS,
            'MAX_RECIPES': QuotaLimits.MAX_RECIPES,
            'MAX_SUPPLIERS': QuotaLimits.MAX_SUPPLIERS,
            'TRAINING_JOBS_PER_DAY': QuotaLimits.TRAINING_JOBS_PER_DAY,
            'FORECAST_GENERATION_PER_DAY': QuotaLimits.FORECAST_GENERATION_PER_DAY,
            'DATASET_SIZE_ROWS': QuotaLimits.DATASET_SIZE_ROWS,
            'FORECAST_HORIZON_DAYS': QuotaLimits.FORECAST_HORIZON_DAYS,
            'HISTORICAL_DATA_ACCESS_DAYS': QuotaLimits.HISTORICAL_DATA_ACCESS_DAYS,
            'BULK_IMPORT_ROWS': QuotaLimits.BULK_IMPORT_ROWS,
            'BULK_EXPORT_ROWS': QuotaLimits.BULK_EXPORT_ROWS,
            'POS_SYNC_INTERVAL_MINUTES': QuotaLimits.POS_SYNC_INTERVAL_MINUTES,
            'API_CALLS_PER_HOUR': QuotaLimits.API_CALLS_PER_HOUR,
            'WEBHOOK_ENDPOINTS': QuotaLimits.WEBHOOK_ENDPOINTS,
            'FILE_STORAGE_GB': QuotaLimits.FILE_STORAGE_GB,
            'REPORT_RETENTION_DAYS': QuotaLimits.REPORT_RETENTION_DAYS,
        }

        quotas = quota_map.get(quota_type, {})
        return quotas.get(tier_enum)
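
# --- Illustrative aside (not part of the file): looking up quotas by name.
# Caveat: an unknown quota_type also returns None, which reads as "unlimited",
# so callers should stick to the keys defined in quota_map above.
from shared.subscription.plans import QuotaLimits

assert QuotaLimits.get_limit("MAX_USERS", "starter") == 5
assert QuotaLimits.get_limit("MAX_USERS", "enterprise") is None  # unlimited
assert QuotaLimits.get_limit("API_CALLS_PER_HOUR", "professional") == 1000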


# ============================================================================
# FEATURE ACCESS CONFIGURATION
# ============================================================================

class PlanFeatures:
    """
    Feature availability by subscription tier
    Each tier includes all features from lower tiers
    """

    # ===== Core Features (All Tiers) =====
    CORE_FEATURES = [
        'inventory_management',
        'sales_tracking',
        'basic_recipes',
        'production_planning',
        'basic_reporting',
        'mobile_app_access',
        'email_support',
        'easy_step_by_step_onboarding',  # NEW: Value-add onboarding
    ]

    # ===== Starter Tier Features =====
    STARTER_FEATURES = CORE_FEATURES + [
        'basic_forecasting',
        'demand_prediction',
        'waste_tracking',
        'order_management',
        'customer_management',
        'supplier_management',
        'batch_tracking',
        'expiry_alerts',
    ]

    # ===== Professional Tier Features =====
    PROFESSIONAL_FEATURES = STARTER_FEATURES + [
        # Advanced Analytics
        'advanced_analytics',
        'custom_reports',
        'sales_analytics',
        'supplier_performance',
        'waste_analysis',
        'profitability_analysis',

        # External Data Integration
        'weather_data_integration',
        'traffic_data_integration',

        # Multi-location
        'multi_location_support',
        'location_comparison',
        'inventory_transfer',

        # Advanced Forecasting
        'batch_scaling',
        'recipe_feasibility_check',
        'seasonal_patterns',
        'longer_forecast_horizon',

        # Integration
        'pos_integration',
        'accounting_export',
        'basic_api_access',

        # Support
        'priority_email_support',
        'phone_support',
    ]

    # ===== Enterprise Tier Features =====
    ENTERPRISE_FEATURES = PROFESSIONAL_FEATURES + [
        # Advanced ML & AI
        'scenario_modeling',
        'what_if_analysis',
        'risk_assessment',
        'advanced_ml_parameters',
        'model_artifacts_access',
        'custom_algorithms',

        # Advanced Integration
        'full_api_access',
        'unlimited_webhooks',
        'erp_integration',
        'custom_integrations',

        # Enterprise Features
        'multi_tenant_management',
        'white_label_option',
        'custom_branding',
        'sso_saml',
        'advanced_permissions',
        'audit_logs_export',
        'compliance_reports',

        # Advanced Analytics
        'benchmarking',
        'competitive_analysis',
        'market_insights',
        'predictive_maintenance',

        # Premium Support
        'dedicated_account_manager',
        'priority_support',
        '24_7_support',
        'custom_training',
        'onsite_support',  # Optional add-on
    ]

    @staticmethod
    def get_features(tier: str) -> List[str]:
        """Get all features for a tier"""
        tier_enum = SubscriptionTier(tier.lower())

        feature_map = {
            SubscriptionTier.STARTER: PlanFeatures.STARTER_FEATURES,
            SubscriptionTier.PROFESSIONAL: PlanFeatures.PROFESSIONAL_FEATURES,
            SubscriptionTier.ENTERPRISE: PlanFeatures.ENTERPRISE_FEATURES,
        }

        return feature_map.get(tier_enum, PlanFeatures.CORE_FEATURES)

    @staticmethod
    def has_feature(tier: str, feature: str) -> bool:
        """Check if a tier has access to a feature"""
        features = PlanFeatures.get_features(tier)
        return feature in features

    @staticmethod
    def requires_professional_tier(feature: str) -> bool:
        """Check if a feature requires the Professional tier or higher"""
        return (
            feature not in PlanFeatures.STARTER_FEATURES and
            feature in PlanFeatures.PROFESSIONAL_FEATURES
        )

    @staticmethod
    def requires_enterprise_tier(feature: str) -> bool:
        """Check if a feature requires the Enterprise tier"""
        return (
            feature not in PlanFeatures.PROFESSIONAL_FEATURES and
            feature in PlanFeatures.ENTERPRISE_FEATURES
        )
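
# --- Illustrative aside (not part of the file): feature gating. Because each
# tier list is built by concatenation, "requires Professional" means present in
# PROFESSIONAL_FEATURES but absent from STARTER_FEATURES.
from shared.subscription.plans import PlanFeatures

assert PlanFeatures.has_feature("starter", "basic_forecasting")
assert not PlanFeatures.has_feature("starter", "pos_integration")
assert PlanFeatures.requires_professional_tier("pos_integration")
assert PlanFeatures.requires_enterprise_tier("sso_saml")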


# ============================================================================
# SUBSCRIPTION PLAN METADATA
# ============================================================================

class SubscriptionPlanMetadata:
    """Complete metadata for each subscription plan"""

    PLANS = {
        SubscriptionTier.STARTER: {
            "name": "Starter",
            "description": "Perfect for small bakeries getting started",
            "tagline": "Essential tools for small operations",
            "popular": False,
            "monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.STARTER],
            "yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.STARTER],
            "trial_days": 14,
            "features": PlanFeatures.STARTER_FEATURES,
            "limits": {
                "users": QuotaLimits.MAX_USERS[SubscriptionTier.STARTER],
                "locations": QuotaLimits.MAX_LOCATIONS[SubscriptionTier.STARTER],
                "products": QuotaLimits.MAX_PRODUCTS[SubscriptionTier.STARTER],
                "forecasts_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[SubscriptionTier.STARTER],
            },
            "support": "Email support (48h response)",
            "recommended_for": "Single location, up to 5 team members",
        },
        SubscriptionTier.PROFESSIONAL: {
            "name": "Professional",
            "description": "For growing bakeries with multiple locations",
            "tagline": "Advanced features & analytics",
            "popular": True,  # Most popular plan
            "monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.PROFESSIONAL],
            "yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.PROFESSIONAL],
            "trial_days": 14,
            "features": PlanFeatures.PROFESSIONAL_FEATURES,
            "limits": {
                "users": QuotaLimits.MAX_USERS[SubscriptionTier.PROFESSIONAL],
                "locations": QuotaLimits.MAX_LOCATIONS[SubscriptionTier.PROFESSIONAL],
                "products": QuotaLimits.MAX_PRODUCTS[SubscriptionTier.PROFESSIONAL],
                "forecasts_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[SubscriptionTier.PROFESSIONAL],
            },
            "support": "Priority email + phone support (24h response)",
            "recommended_for": "Multi-location operations, up to 20 team members",
        },
        SubscriptionTier.ENTERPRISE: {
            "name": "Enterprise",
            "description": "For large bakery chains and franchises",
            "tagline": "Unlimited scale & custom solutions",
            "popular": False,
            "monthly_price": PlanPricing.MONTHLY_PRICES[SubscriptionTier.ENTERPRISE],
            "yearly_price": PlanPricing.YEARLY_PRICES[SubscriptionTier.ENTERPRISE],
            "trial_days": 30,
            "features": PlanFeatures.ENTERPRISE_FEATURES,
            "limits": {
                "users": "Unlimited",
                "locations": "Unlimited",
                "products": "Unlimited",
                "forecasts_per_day": "Unlimited",
            },
            "support": "24/7 dedicated support + account manager",
            "recommended_for": "Enterprise operations, unlimited scale",
            "custom_pricing": True,
            "contact_sales": True,
        },
    }

    @staticmethod
    def get_plan_info(tier: str) -> Dict[str, Any]:
        """Get complete plan information"""
        tier_enum = SubscriptionTier(tier.lower())
        return SubscriptionPlanMetadata.PLANS.get(tier_enum, {})

    @staticmethod
    def get_all_plans() -> Dict[SubscriptionTier, Dict[str, Any]]:
        """Get information for all plans"""
        return SubscriptionPlanMetadata.PLANS
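
# --- Illustrative aside (not part of the file): rendering a plan card. Note
# that an unrecognized tier string raises ValueError from the enum conversion
# before the .get(..., {}) fallback can ever apply.
from shared.subscription.plans import SubscriptionPlanMetadata

plan = SubscriptionPlanMetadata.get_plan_info("professional")
print(plan["name"], plan["monthly_price"], plan["trial_days"])  # Professional 149.00 14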


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def get_training_job_quota(tier: str) -> Optional[int]:
    """Get training job daily quota for tier"""
    return QuotaLimits.get_limit('TRAINING_JOBS_PER_DAY', tier)


def get_forecast_quota(tier: str) -> Optional[int]:
    """Get forecast generation daily quota for tier"""
    return QuotaLimits.get_limit('FORECAST_GENERATION_PER_DAY', tier)


def get_dataset_size_limit(tier: str) -> Optional[int]:
    """Get dataset size limit for tier"""
    return QuotaLimits.get_limit('DATASET_SIZE_ROWS', tier)


def get_forecast_horizon_limit(tier: str) -> int:
    """Get forecast horizon limit for tier"""
    return QuotaLimits.get_limit('FORECAST_HORIZON_DAYS', tier) or 7


def get_historical_data_limit(tier: str) -> Optional[int]:
    """Get historical data access limit for tier"""
    return QuotaLimits.get_limit('HISTORICAL_DATA_ACCESS_DAYS', tier)


def can_access_feature(tier: str, feature: str) -> bool:
    """Check if tier can access a feature"""
    return PlanFeatures.has_feature(tier, feature)
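
# --- Illustrative aside (not part of the file): the helpers are thin wrappers
# over QuotaLimits/PlanFeatures, convenient as import targets for other services.
from shared.subscription.plans import get_forecast_quota, get_forecast_horizon_limit

assert get_forecast_quota("starter") == 10
assert get_forecast_horizon_limit("enterprise") == 365  # falls back to 7 only if unset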


def get_tier_comparison() -> Dict[str, Any]:
    """
    Get feature comparison across all tiers
    Useful for pricing pages
    """
    return {
        "tiers": ["starter", "professional", "enterprise"],
        "features": {
            "core": PlanFeatures.CORE_FEATURES,
            "starter_only": list(set(PlanFeatures.STARTER_FEATURES) - set(PlanFeatures.CORE_FEATURES)),
            "professional_only": list(set(PlanFeatures.PROFESSIONAL_FEATURES) - set(PlanFeatures.STARTER_FEATURES)),
            "enterprise_only": list(set(PlanFeatures.ENTERPRISE_FEATURES) - set(PlanFeatures.PROFESSIONAL_FEATURES)),
        },
        "pricing": {
            tier.value: {
                "monthly": float(PlanPricing.MONTHLY_PRICES[tier]),
                "yearly": float(PlanPricing.YEARLY_PRICES[tier]),
            }
            for tier in SubscriptionTier
        },
    }
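
# --- Illustrative aside (not part of the file): a pricing page can serialize
# the comparison directly; the set differencing means the per-tier feature
# lists come back unordered, so sort them before display if order matters.
import json
from shared.subscription.plans import get_tier_comparison

print(json.dumps(get_tier_comparison()["pricing"], indent=2))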