Initial commit - production deployment

2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

gateway/app/__init__.py Normal file
@@ -0,0 +1,46 @@
# ================================================================
# GATEWAY SERVICE CONFIGURATION
# gateway/app/core/config.py
# ================================================================
"""
Gateway service configuration
Central API Gateway for all microservices
"""
from shared.config.base import BaseServiceSettings
import os
from typing import Dict, List
class GatewaySettings(BaseServiceSettings):
"""Gateway-specific settings"""
# Service Identity
APP_NAME: str = "Bakery Forecasting Gateway"
SERVICE_NAME: str = "gateway"
DESCRIPTION: str = "API Gateway for Bakery Forecasting Platform"
# Gateway-specific Redis database
REDIS_DB: int = 6
# Service Discovery
CONSUL_URL: str = os.getenv("CONSUL_URL", "http://consul:8500")
ENABLE_SERVICE_DISCOVERY: bool = os.getenv("ENABLE_SERVICE_DISCOVERY", "false").lower() == "true"
# Load Balancing
ENABLE_LOAD_BALANCING: bool = os.getenv("ENABLE_LOAD_BALANCING", "true").lower() == "true"
LOAD_BALANCER_ALGORITHM: str = os.getenv("LOAD_BALANCER_ALGORITHM", "round_robin")
# Circuit Breaker
CIRCUIT_BREAKER_ENABLED: bool = os.getenv("CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
CIRCUIT_BREAKER_FAILURE_THRESHOLD: int = int(os.getenv("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
CIRCUIT_BREAKER_RECOVERY_TIMEOUT: int = int(os.getenv("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
# Request/Response Settings
MAX_REQUEST_SIZE: int = int(os.getenv("MAX_REQUEST_SIZE", "10485760")) # 10MB
REQUEST_TIMEOUT: int = int(os.getenv("REQUEST_TIMEOUT", "30"))
# Gateway doesn't need a database
DATABASE_URL: str = ""
settings = GatewaySettings()
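
The settings above are resolved once at import time from environment variables via os.getenv. A minimal standalone sketch of the boolean parsing convention they rely on (illustrative only, not part of the commit):

import os

def env_bool(name: str, default: str = "false") -> bool:
    """Mirror of the pattern used above: only the string 'true' (any casing) is truthy."""
    return os.getenv(name, default).lower() == "true"

os.environ["ENABLE_LOAD_BALANCING"] = "True"
print(env_bool("ENABLE_LOAD_BALANCING"))            # True
print(env_bool("CIRCUIT_BREAKER_ENABLED", "true"))  # True (unset, falls back to the default)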


@@ -0,0 +1,346 @@
"""
Unified Header Management System for API Gateway
Centralized header injection, forwarding, and validation
"""
import structlog
from fastapi import Request
from typing import Dict, Any, Optional, List
logger = structlog.get_logger()
class HeaderManager:
"""
    Centralized header management for consistent header handling across the gateway
"""
# Standard header names (lowercase for consistency)
STANDARD_HEADERS = {
'user_id': 'x-user-id',
'user_email': 'x-user-email',
'user_role': 'x-user-role',
'user_type': 'x-user-type',
'service_name': 'x-service-name',
'tenant_id': 'x-tenant-id',
'subscription_tier': 'x-subscription-tier',
'subscription_status': 'x-subscription-status',
'is_demo': 'x-is-demo',
'demo_session_id': 'x-demo-session-id',
'demo_account_type': 'x-demo-account-type',
'tenant_access_type': 'x-tenant-access-type',
'can_view_children': 'x-can-view-children',
'parent_tenant_id': 'x-parent-tenant-id',
'forwarded_by': 'x-forwarded-by',
'request_id': 'x-request-id'
}
# Headers that should be sanitized/removed from incoming requests
SANITIZED_HEADERS = [
'x-subscription-',
'x-user-',
'x-tenant-',
'x-demo-',
'x-forwarded-by'
]
# Headers that should be forwarded to downstream services
FORWARDABLE_HEADERS = [
'authorization',
'content-type',
'accept',
'accept-language',
'user-agent',
'x-internal-service', # Required for internal service-to-service ML/alert triggers
'stripe-signature' # Required for Stripe webhook signature verification
]
def __init__(self):
self._initialized = False
def initialize(self):
"""Initialize header manager"""
if not self._initialized:
logger.info("HeaderManager initialized")
self._initialized = True
def sanitize_incoming_headers(self, request: Request) -> None:
"""
Remove sensitive headers from incoming request to prevent spoofing
"""
if not hasattr(request.headers, '_list'):
return
# Filter out headers that start with sanitized prefixes
sanitized_headers = [
(k, v) for k, v in request.headers.raw
if not any(k.decode().lower().startswith(prefix.lower())
for prefix in self.SANITIZED_HEADERS)
]
request.headers.__dict__["_list"] = sanitized_headers
logger.debug("Sanitized incoming headers")
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
tenant_id: Optional[str] = None) -> Dict[str, str]:
"""
Inject standardized context headers into request
Returns dict of injected headers for reference
"""
injected_headers = {}
# Ensure headers list exists
if not hasattr(request.headers, '_list'):
request.headers.__dict__["_list"] = []
# Store headers in request.state for cross-middleware access
request.state.injected_headers = {}
# User context headers
if user_context.get('user_id'):
header_name = self.STANDARD_HEADERS['user_id']
header_value = str(user_context['user_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('email'):
header_name = self.STANDARD_HEADERS['user_email']
header_value = str(user_context['email'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('role'):
header_name = self.STANDARD_HEADERS['user_role']
header_value = str(user_context['role'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# User type (service vs regular user)
if user_context.get('type'):
header_name = self.STANDARD_HEADERS['user_type']
header_value = str(user_context['type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Service name for service tokens
if user_context.get('service'):
header_name = self.STANDARD_HEADERS['service_name']
header_value = str(user_context['service'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Tenant context
if tenant_id:
header_name = self.STANDARD_HEADERS['tenant_id']
header_value = str(tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Subscription context
if user_context.get('subscription_tier'):
header_name = self.STANDARD_HEADERS['subscription_tier']
header_value = str(user_context['subscription_tier'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('subscription_status'):
header_name = self.STANDARD_HEADERS['subscription_status']
header_value = str(user_context['subscription_status'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Demo session context
is_demo = user_context.get('is_demo', False)
if is_demo:
header_name = self.STANDARD_HEADERS['is_demo']
header_value = "true"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_session_id'):
header_name = self.STANDARD_HEADERS['demo_session_id']
header_value = str(user_context['demo_session_id'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
if user_context.get('demo_account_type'):
header_name = self.STANDARD_HEADERS['demo_account_type']
header_value = str(user_context['demo_account_type'])
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Hierarchical access context
if tenant_id:
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
can_view_children = getattr(request.state, 'can_view_children', False)
header_name = self.STANDARD_HEADERS['tenant_access_type']
header_value = str(tenant_access_type)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
header_name = self.STANDARD_HEADERS['can_view_children']
header_value = str(can_view_children).lower()
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Parent tenant ID if hierarchical access
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
if parent_tenant_id:
header_name = self.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Gateway identification
header_name = self.STANDARD_HEADERS['forwarded_by']
header_value = "bakery-gateway"
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
# Request ID if available
request_id = getattr(request.state, 'request_id', None)
if request_id:
header_name = self.STANDARD_HEADERS['request_id']
header_value = str(request_id)
self._add_header(request, header_name, header_value)
injected_headers[header_name] = header_value
request.state.injected_headers[header_name] = header_value
logger.info("🔧 Injected context headers",
user_id=user_context.get('user_id'),
user_type=user_context.get('type', ''),
service_name=user_context.get('service', ''),
role=user_context.get('role', ''),
tenant_id=tenant_id,
is_demo=is_demo,
demo_session_id=user_context.get('demo_session_id', ''),
path=request.url.path)
return injected_headers
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
"""
Safely add header to request
"""
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name}: {e}")
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
"""
Get headers that should be forwarded to downstream services
Includes both original request headers and injected context headers
"""
forwardable_headers = {}
# Add forwardable original headers
for header_name in self.FORWARDABLE_HEADERS:
header_value = request.headers.get(header_name)
if header_value:
forwardable_headers[header_name] = header_value
# Add injected context headers from request.state
if hasattr(request.state, 'injected_headers'):
for header_name, header_value in request.state.injected_headers.items():
forwardable_headers[header_name] = header_value
# Add authorization header if present
auth_header = request.headers.get('authorization')
if auth_header:
forwardable_headers['authorization'] = auth_header
return forwardable_headers
def get_all_headers_for_proxy(self, request: Request,
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
Get complete set of headers for proxying to downstream services
"""
headers = self.get_forwardable_headers(request)
# Add any additional headers
if additional_headers:
headers.update(additional_headers)
# Remove host header as it will be set by httpx
headers.pop('host', None)
return headers
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
"""
Validate that required headers are present
"""
missing_headers = []
for header_name in required_headers:
# Check in injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
continue
# Check in request headers
if request.headers.get(header_name):
continue
missing_headers.append(header_name)
if missing_headers:
logger.warning(f"Missing required headers: {missing_headers}")
return False
return True
def get_header_value(self, request: Request, header_name: str,
default: Optional[str] = None) -> Optional[str]:
"""
Get header value from either injected headers or request headers
"""
# Check injected headers first
if hasattr(request.state, 'injected_headers'):
if header_name in request.state.injected_headers:
return request.state.injected_headers[header_name]
# Check request headers
return request.headers.get(header_name, default)
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
"""
Allow middleware to add headers to the unified header system
This ensures all headers are available for proxying
"""
# Ensure injected_headers exists
if not hasattr(request.state, 'injected_headers'):
request.state.injected_headers = {}
# Add header to injected_headers
request.state.injected_headers[header_name] = header_value
# Also add to actual request headers for compatibility
try:
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
except Exception as e:
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
logger.debug(f"Middleware added header: {header_name} = {header_value}")
# Global instance for easy access
header_manager = HeaderManager()
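
A hedged usage sketch (not part of the commit) of how the three main HeaderManager calls are meant to be combined by a proxy handler; the user context values and tenant ID are placeholders, and the import path is taken from the gateway's main.py below:

from typing import Dict
from fastapi import Request
from app.core.header_manager import header_manager

async def build_proxy_headers(request: Request) -> Dict[str, str]:
    # 1. Strip spoofable x-user-*/x-tenant-*/x-demo-* headers from the incoming request
    header_manager.sanitize_incoming_headers(request)
    # 2. Inject verified context (dict shape assumed from the keys read by inject_context_headers)
    user_context = {"user_id": "u-123", "email": "owner@example.com", "role": "admin"}
    header_manager.inject_context_headers(request, user_context, tenant_id="t-456")
    # 3. Collect everything that is safe to forward to a downstream service
    return header_manager.get_all_headers_for_proxy(request, additional_headers={"x-request-id": "req-1"})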


@@ -0,0 +1,65 @@
"""
Service discovery for API Gateway
"""
import logging
import httpx
from typing import Optional, Dict
import json
from app.core.config import settings
logger = logging.getLogger(__name__)
class ServiceDiscovery:
"""Service discovery client"""
def __init__(self):
self.consul_url = settings.CONSUL_URL if hasattr(settings, 'CONSUL_URL') else None
self.service_cache: Dict[str, str] = {}
async def get_service_url(self, service_name: str) -> Optional[str]:
"""Get service URL from service discovery"""
# Return cached URL if available
if service_name in self.service_cache:
return self.service_cache[service_name]
# Try Consul if enabled
if self.consul_url and getattr(settings, 'ENABLE_SERVICE_DISCOVERY', False):
try:
url = await self._get_from_consul(service_name)
if url:
self.service_cache[service_name] = url
return url
except Exception as e:
logger.warning(f"Failed to get {service_name} from Consul: {e}")
# Fall back to environment variables
return self._get_from_env(service_name)
async def _get_from_consul(self, service_name: str) -> Optional[str]:
"""Get service URL from Consul"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.consul_url}/v1/health/service/{service_name}?passing=true"
)
if response.status_code == 200:
services = response.json()
if services:
service = services[0]
address = service['Service']['Address']
port = service['Service']['Port']
return f"http://{address}:{port}"
except Exception as e:
logger.error(f"Consul query failed: {e}")
return None
def _get_from_env(self, service_name: str) -> Optional[str]:
"""Get service URL from environment variables"""
env_var = f"{service_name.upper().replace('-', '_')}_SERVICE_URL"
return getattr(settings, env_var, None)
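
A hedged usage sketch (not part of the commit) of the Consul-then-environment fallback implemented above; the module path and the "tenant" service name are assumptions for illustration:

import asyncio
from app.core.discovery import ServiceDiscovery  # assumed import path; the file's location is not shown here

async def main() -> None:
    discovery = ServiceDiscovery()
    # Consul lookup when ENABLE_SERVICE_DISCOVERY is true, otherwise the TENANT_SERVICE_URL setting
    url = await discovery.get_service_url("tenant")
    print(url or "tenant service not configured")

asyncio.run(main())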

gateway/app/main.py Normal file

@@ -0,0 +1,512 @@
"""
API Gateway - Central entry point for all microservices
Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
import resource
import os
import time
from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
import httpx
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from shared.service_base import StandardFastAPIService
from app.core.config import settings
from app.core.header_manager import header_manager
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.rate_limiting import APIRateLimitMiddleware
from app.middleware.subscription import SubscriptionMiddleware
from app.middleware.demo_middleware import DemoMiddleware
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
from app.routes import auth, tenant, registration, nominatim, subscription, demo, pos, geocoding, poi_context, webhooks, telemetry
# Initialize logger
logger = structlog.get_logger()
# Check file descriptor limits
try:
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft_limit < 1024:
logger.warning(f"Low file descriptor limit detected: {soft_limit}")
else:
logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
except Exception as e:
logger.debug(f"Could not check file descriptor limits: {e}")
# Global Redis client for SSE streaming
redis_client = None
class GatewayService(StandardFastAPIService):
"""Gateway Service with standardized monitoring setup"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Initialize HeaderManager early
header_manager.initialize()
logger.info("HeaderManager initialized")
# Initialize Redis during service creation so it's available when needed
try:
# We need to run the async initialization in a sync context
import asyncio
try:
# Check if there's already a running event loop
loop = asyncio.get_running_loop()
# If there is, we'll initialize Redis later in on_startup
self.redis_initialized = False
self.redis_client = None
except RuntimeError:
# No event loop running, safe to run the async function
import asyncio
import nest_asyncio
nest_asyncio.apply() # Allow nested event loops
async def init_redis():
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
return await get_redis_client()
self.redis_client = asyncio.run(init_redis())
self.redis_initialized = True
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to initialize Redis during service creation: {e}")
self.redis_initialized = False
self.redis_client = None
async def on_startup(self, app):
"""Custom startup logic for Gateway"""
global redis_client
# Initialize Redis if not already done during service creation
if not self.redis_initialized:
try:
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
self.redis_client = await get_redis_client()
redis_client = self.redis_client # Update global variable
self.redis_initialized = True
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to connect to Redis during startup: {e}")
# Register custom metrics for gateway-specific operations
if self.telemetry_providers and self.telemetry_providers.app_metrics:
logger.info("Gateway-specific metrics tracking enabled")
await super().on_startup(app)
async def on_shutdown(self, app):
"""Custom shutdown logic for Gateway"""
await super().on_shutdown(app)
# Close Redis
await close_redis()
logger.info("Redis connection closed")
# Create service instance
service = GatewayService(
service_name="gateway",
app_name="Bakery Forecasting API Gateway",
description="Central API Gateway for bakery forecasting microservices",
version="1.0.0",
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
cors_origins=settings.CORS_ORIGINS_LIST,
enable_metrics=True,
enable_health_checks=True,
enable_tracing=True,
enable_cors=True
)
# Create FastAPI app
app = service.create_app()
# Add API rate limiting middleware with Redis client - this needs to be done after app creation
# but before other middleware that might depend on it
# Wait for Redis to be initialized if not already done
if not hasattr(service, 'redis_client') or not service.redis_client:
# Wait briefly for Redis initialization to complete
import time
time.sleep(1)
# Check again after allowing time for initialization
if hasattr(service, 'redis_client') and service.redis_client:
app.add_middleware(APIRateLimitMiddleware, redis_client=service.redis_client)
logger.info("API rate limiting middleware enabled")
else:
logger.warning("Redis client not available for API rate limiting middleware")
else:
app.add_middleware(APIRateLimitMiddleware, redis_client=service.redis_client)
logger.info("API rate limiting middleware enabled")
# Add gateway-specific middleware (added in REVERSE order of execution: the last middleware added runs first)
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware -> APIRateLimitMiddleware
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
app.add_middleware(AuthMiddleware)
app.add_middleware(DemoMiddleware)
app.add_middleware(RequestIDMiddleware)
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(registration.router, prefix="/api/v1", tags=["registration"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
# Notification routes are now handled by tenant router at /api/v1/tenants/{tenant_id}/notifications/*
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
# Include webhooks at the root level to handle /api/v1/webhooks/*
# Webhook routes are defined with full /api/v1/webhooks/* paths for consistency
app.include_router(webhooks.router, prefix="", tags=["webhooks"])
# Include telemetry routes for frontend OpenTelemetry data
app.include_router(telemetry.router, prefix="/api/v1", tags=["telemetry"])
# ================================================================
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
# ================================================================
def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
"""Determine which Redis channels to subscribe to based on filters"""
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
channels = []
if not channel_filters:
# Subscribe to ALL channels (backward compatible)
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")
channels.append(f"alerts:{tenant_id}") # Legacy
return channels
# Parse filters and expand wildcards
for filter_pattern in channel_filters:
if filter_pattern == "*.*":
for domain in all_domains:
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
channels.append(f"tenant:{tenant_id}:recommendations")
elif filter_pattern.endswith(".*"):
domain = filter_pattern.split(".")[0]
for event_class in all_classes:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern.startswith("*."):
event_class = filter_pattern.split(".")[1]
if event_class == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
for domain in all_domains:
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
elif filter_pattern == "recommendations":
channels.append(f"tenant:{tenant_id}:recommendations")
else:
channels.append(f"tenant:{tenant_id}:{filter_pattern}")
return list(set(channels))
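# Worked example (illustrative comment only): for tenant_id="t1" and
# channel_filters=["inventory.*", "*.alerts"], the expansion above yields, after
# de-duplication (order not guaranteed):
#   tenant:t1:inventory.alerts, tenant:t1:inventory.notifications,
#   tenant:t1:production.alerts, tenant:t1:supply_chain.alerts,
#   tenant:t1:demand.alerts, tenant:t1:operations.alerts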
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
"""Load initial state from Redis cache based on channel filters"""
initial_events = []
try:
if not channel_filters:
# Legacy cache
legacy_cache_key = f"active_alerts:{tenant_id}"
cached_data = await redis_client.get(legacy_cache_key)
if cached_data:
return json.loads(cached_data)
# New domain-specific caches
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
all_classes = ["alerts", "notifications"]
for domain in all_domains:
for event_class in all_classes:
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
# Recommendations
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(recommendations_cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
# Load based on specific filters
for filter_pattern in channel_filters:
if "." in filter_pattern:
parts = filter_pattern.split(".")
domain = parts[0] if parts[0] != "*" else None
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None
if domain and event_class:
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif domain and not event_class:
for ec in ["alerts", "notifications"]:
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif not domain and event_class:
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
for d in all_domains:
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
elif filter_pattern == "recommendations":
cache_key = f"active_events:{tenant_id}:recommendations"
cached_data = await redis_client.get(cache_key)
if cached_data:
initial_events.extend(json.loads(cached_data))
return initial_events
except Exception as e:
logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
return []
def _determine_event_type(event_data: dict) -> str:
"""Determine SSE event type from event data"""
if 'event_class' in event_data:
return event_data['event_class']
if 'item_type' in event_data:
if event_data['item_type'] == 'recommendation':
return 'recommendation'
else:
return 'alert'
return 'alert'
# ================================================================
# SERVER-SENT EVENTS (SSE) ENDPOINT
# ================================================================
@app.get("/api/v1/events")
async def events_stream(
request: Request,
tenant_id: str,
channels: str = None
):
"""
Server-Sent Events stream for real-time notifications with multi-channel support.
Query Parameters:
tenant_id: Tenant identifier (required)
channels: Comma-separated channel filters (optional)
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract user context from request state
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
# Parse channel filters
channel_filters = []
if channels:
channel_filters = [c.strip() for c in channels.split(',') if c.strip()]
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")
async def event_generator():
"""Generate server-sent events from Redis pub/sub"""
pubsub = None
try:
pubsub = redis_client.pubsub()
logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")
# Determine channels
subscription_channels = _get_subscription_channels(tenant_id, channel_filters)
# Subscribe
if subscription_channels:
await pubsub.subscribe(*subscription_channels)
logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
else:
legacy_channel = f"alerts:{tenant_id}"
await pubsub.subscribe(legacy_channel)
# Connection event
yield f"event: connection\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"
# Initial state
initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
if initial_events:
logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
yield f"event: initial_state\n"
yield f"data: {json.dumps(initial_events)}\n\n"
heartbeat_counter = 0
while True:
if await request.is_disconnected():
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
break
try:
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
if message and message['type'] == 'message':
event_data = json.loads(message['data'])
event_type = _determine_event_type(event_data)
event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']
yield f"event: {event_type}\n"
yield f"data: {json.dumps(event_data)}\n\n"
except asyncio.TimeoutError:
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
except asyncio.CancelledError:
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
except Exception as e:
logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
finally:
if pubsub:
try:
await pubsub.unsubscribe()
await pubsub.close()
except Exception as e:
logger.error(f"Error closing pubsub: {e}")
logger.info(f"SSE connection closed for tenant: {tenant_id}")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control",
}
)
# ================================================================
# WEBSOCKET ROUTING FOR TRAINING SERVICE
# ================================================================
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
"""WebSocket proxy with token verification for training service"""
token = websocket.query_params.get("token")
if not token:
await websocket.accept()
await websocket.close(code=1008, reason="Authentication token required")
return
# Verify token
from shared.auth.jwt_handler import JWTHandler
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
try:
payload = jwt_handler.verify_token(token)
if not payload or not payload.get('user_id'):
await websocket.accept()
await websocket.close(code=1008, reason="Invalid token")
return
except Exception as e:
await websocket.accept()
await websocket.close(code=1008, reason="Token verification failed")
return
await websocket.accept()
# Build WebSocket URL to training service
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
training_ws = None
try:
import websockets
from websockets.protocol import State
training_ws = await websockets.connect(
training_ws_url,
ping_interval=120,
ping_timeout=60,
close_timeout=60,
open_timeout=30
)
async def forward_frontend_to_training():
try:
while training_ws and training_ws.state == State.OPEN:
data = await websocket.receive()
if data.get("type") == "websocket.receive":
if "text" in data:
await training_ws.send(data["text"])
elif "bytes" in data:
await training_ws.send(data["bytes"])
elif data.get("type") == "websocket.disconnect":
break
except Exception:
pass
async def forward_training_to_frontend():
try:
while training_ws and training_ws.state == State.OPEN:
message = await training_ws.recv()
await websocket.send_text(message)
except Exception:
pass
await asyncio.gather(
forward_frontend_to_training(),
forward_training_to_frontend(),
return_exceptions=True
)
except Exception as e:
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
finally:
if training_ws and training_ws.state == State.OPEN:
try:
await training_ws.close()
except:
pass
try:
if not websocket.client_state.name == 'DISCONNECTED':
await websocket.close(code=1000, reason="Proxy closed")
except:
pass
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
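
A hedged client-side sketch (not part of the gateway code) of how a consumer might read the SSE stream defined above with httpx; the base URL, JWT and channel filters are placeholders:

import asyncio
import httpx

async def consume_events(base_url: str, token: str, tenant_id: str) -> None:
    """Print raw SSE payloads from the gateway's /api/v1/events stream."""
    params = {"tenant_id": tenant_id, "channels": "inventory.*,recommendations", "token": token}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", f"{base_url}/api/v1/events", params=params) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])  # JSON body of connection/initial_state/alert/heartbeat events

# asyncio.run(consume_events("http://localhost:8000", "<jwt>", "<tenant-id>"))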


@@ -0,0 +1,649 @@
# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
FIXED VERSION - Proper JWT verification and token structure handling
"""
import structlog
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional, Dict, Any
import httpx
import json
import hashlib
from app.core.config import settings
from app.core.header_manager import header_manager
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
logger = structlog.get_logger()
# JWT handler for local token validation - using SAME configuration as auth service
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
# Routes that don't require authentication
PUBLIC_ROUTES = [
"/health",
"/metrics",
"/docs",
"/redoc",
"/openapi.json",
"/api/v1/auth/login",
"/api/v1/auth/register",
"/api/v1/auth/refresh",
"/api/v1/auth/verify",
"/api/v1/auth/start-registration", # Registration step 1 - SetupIntent creation
"/api/v1/auth/complete-registration", # Registration step 2 - Completion after 3DS
"/api/v1/registration/payment-setup", # New registration payment setup endpoint
"/api/v1/registration/complete", # New registration completion endpoint
"/api/v1/registration/state/", # Registration state check
"/api/v1/auth/verify-email", # Email verification
"/api/v1/auth/password/reset-request", # Password reset request - no auth required
"/api/v1/auth/password/reset", # Password reset with token - no auth required
"/api/v1/nominatim/search",
"/api/v1/plans",
"/api/v1/demo/accounts",
"/api/v1/demo/sessions",
"/api/v1/webhooks/stripe", # Stripe webhook endpoint - bypasses auth for signature verification
"/api/v1/webhooks/generic", # Generic webhook endpoint
"/api/v1/telemetry/v1/traces", # Frontend telemetry traces - no auth for performance
"/api/v1/telemetry/v1/metrics", # Frontend telemetry metrics - no auth for performance
"/api/v1/telemetry/health" # Telemetry health check
]
# Routes accessible with demo session (no JWT required, just demo session header)
DEMO_ACCESSIBLE_ROUTES = [
"/api/v1/tenants/", # All tenant endpoints accessible in demo mode
]
class AuthMiddleware(BaseHTTPMiddleware):
"""
Enhanced Authentication Middleware with Tenant Access Control
"""
def __init__(self, app, redis_client=None):
super().__init__(app)
self.redis_client = redis_client # For caching and rate limiting
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with enhanced authentication and tenant access control"""
# Skip authentication for OPTIONS requests (CORS preflight)
if request.method == "OPTIONS":
return await call_next(request)
# SECURITY: Remove any incoming sensitive headers using HeaderManager
header_manager.sanitize_incoming_headers(request)
# Skip authentication for public routes
if self._is_public_route(request.url.path):
return await call_next(request)
# ✅ Check if demo middleware already set user context OR check query param for SSE
demo_session_header = request.headers.get("X-Demo-Session-Id")
demo_session_query = request.query_params.get("demo_session_id") # For SSE endpoint
logger.info(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")
# For SSE endpoint with demo_session_id in query params, validate it here
if request.url.path == "/api/v1/events" and demo_session_query and not hasattr(request.state, "is_demo_session"):
logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_query}")
# Validate demo session via demo-session service using JWT service token
import httpx
try:
# Create service token for gateway-to-demo-session communication
service_token = jwt_handler.create_service_token(service_name="gateway")
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://demo-session-service:8000/api/v1/demo/sessions/{demo_session_query}",
headers={"Authorization": f"Bearer {service_token}"}
)
if response.status_code == 200:
session_data = response.json()
# Set demo session context
request.state.is_demo_session = True
request.state.user = {
"user_id": f"demo-user-{demo_session_query}",
"email": f"demo-{demo_session_query}@demo.local",
"tenant_id": session_data.get("virtual_tenant_id"),
"demo_session_id": demo_session_query,
}
request.state.tenant_id = session_data.get("virtual_tenant_id")
logger.info(f"✅ Demo session validated for SSE: {demo_session_query}")
else:
logger.warning(f"Invalid demo session for SSE: {demo_session_query}")
return JSONResponse(
status_code=401,
content={"detail": "Invalid demo session"}
)
except Exception as e:
logger.error(f"Failed to validate demo session for SSE: {e}")
return JSONResponse(
status_code=503,
content={"detail": "Demo session service unavailable"}
)
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
if hasattr(request.state, "user") and request.state.user:
logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
# Demo middleware already validated and set user context
# But we still need to inject context headers for downstream services
user_context = request.state.user
tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
# For demo sessions, get the actual subscription tier from the tenant service
# instead of always defaulting to enterprise
if not user_context.get("subscription_tier"):
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
if subscription_tier:
user_context["subscription_tier"] = subscription_tier
else:
# Fallback to enterprise for demo if no tier is found
user_context["subscription_tier"] = "enterprise"
logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)
await self._inject_context_headers(request, user_context, tenant_id)
return await call_next(request)
# ✅ STEP 1: Extract and validate JWT token
token = self._extract_token(request)
if not token:
logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
return JSONResponse(
status_code=401,
content={"detail": "Authentication required"}
)
# ✅ STEP 2: Verify token and get user context
user_context = await self._verify_token(token, request)
if not user_context:
logger.warning(f"Invalid token for route: {request.url.path}")
return JSONResponse(
status_code=401,
content={"detail": "User not authenticated"}
)
# ✅ STEP 3: Extract tenant context from URL using shared utility
tenant_id = extract_tenant_id_from_path(request.url.path)
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
if tenant_id and is_tenant_scoped_path(request.url.path):
# Skip tenant access verification for service tokens (services have admin access)
if user_context.get("type") != "service":
# Use TenantAccessManager for gateway-level verification with caching
if self.redis_client and tenant_access_manager.redis_client is None:
tenant_access_manager.redis_client = self.redis_client
has_access = await tenant_access_manager.verify_basic_tenant_access(
user_context["user_id"],
tenant_id
)
if not has_access:
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
return JSONResponse(
status_code=403,
content={"detail": f"Access denied to tenant {tenant_id}"}
)
else:
logger.debug(f"Service token granted access to tenant {tenant_id}",
service=user_context.get("service"))
# Get tenant subscription tier and inject into user context
# NEW: Use JWT data if available, skip HTTP call
if user_context.get("subscription_from_jwt"):
subscription_tier = user_context.get("subscription_tier")
logger.debug("Using subscription tier from JWT", tier=subscription_tier)
else:
# Only for old tokens - remove after full rollout
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
if subscription_tier:
user_context["subscription_tier"] = subscription_tier
# Check hierarchical access to determine access type and permissions
hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
user_context["user_id"],
tenant_id
)
# Set tenant context in request state
request.state.tenant_id = tenant_id
request.state.tenant_verified = True
request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
request.state.can_view_children = hierarchical_access.get("can_view_children", False)
logger.debug(f"Tenant access verified",
user_id=user_context["user_id"],
tenant_id=tenant_id,
subscription_tier=subscription_tier,
access_type=hierarchical_access.get("access_type"),
can_view_children=hierarchical_access.get("can_view_children"),
path=request.url.path)
# ✅ STEP 5: Inject user context into request
request.state.user = user_context
request.state.authenticated = True
# ✅ STEP 6: Add context headers for downstream services
await self._inject_context_headers(request, user_context, tenant_id)
logger.debug(f"Authenticated request",
user_email=user_context['email'],
tenant_id=tenant_id,
path=request.url.path)
# Process the request
response = await call_next(request)
# Add token expiry warning header if token is near expiry
if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
response.headers["X-Token-Refresh-Suggested"] = "true"
return response
def _is_public_route(self, path: str) -> bool:
"""Check if route requires authentication"""
return any(path.startswith(route) for route in PUBLIC_ROUTES)
def _extract_token(self, request: Request) -> Optional[str]:
"""
Extract JWT token from Authorization header or query params for SSE.
For SSE endpoints (/api/v1/events), browsers' EventSource API cannot send
custom headers, so we must accept token as query parameter.
For all other routes, token must be in Authorization header (more secure).
Security note: Query param tokens are logged. Use short expiry and filter logs.
"""
# SSE endpoint exception: token in query param (EventSource API limitation)
if request.url.path == "/api/v1/events":
token = request.query_params.get("token")
if token:
logger.debug("Token extracted from query param for SSE endpoint")
return token
logger.warning("SSE request missing token in query param")
return None
# Standard authentication: Authorization header for all other routes
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
return None
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
"""
Verify JWT token with improved fallback strategy
FIXED: Better error handling and token structure validation
"""
# Strategy 1: Try local JWT validation first (fast)
try:
payload = jwt_handler.verify_token(token)
if payload and self._validate_token_payload(payload):
logger.debug("Token validated locally")
# NEW: Check token freshness for subscription changes (async)
if payload.get("tenant_id") and request:
try:
is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
if not is_fresh:
logger.warning("Stale token detected - subscription changed since token was issued",
user_id=payload.get("user_id"),
tenant_id=payload.get("tenant_id"))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token is stale - subscription has changed"
)
except Exception as e:
logger.warning("Token freshness check failed, allowing token", error=str(e))
# Allow token if check fails (fail open for availability)
# Check if token is near expiry and set flag for response header
if request:
import time
exp_time = payload.get("exp", 0)
current_time = time.time()
time_until_expiry = exp_time - current_time
if time_until_expiry < 300: # 5 minutes
request.state.token_near_expiry = True
# Convert JWT payload to user context format
return self._jwt_payload_to_user_context(payload)
except Exception as e:
logger.debug(f"Local token validation failed: {e}")
# Strategy 2: Check cache for recently validated tokens
if self.redis_client:
try:
cached_user = await self._get_cached_user(token)
if cached_user:
logger.debug("Token found in cache")
return cached_user
except Exception as e:
logger.warning(f"Cache lookup failed: {e}")
# Strategy 3: Verify with auth service (authoritative)
try:
user_context = await self._verify_with_auth_service(token)
if user_context:
# Cache successful validation
if self.redis_client:
await self._cache_user(token, user_context)
logger.debug("Token validated by auth service")
return user_context
except Exception as e:
logger.error(f"Auth service validation failed: {e}")
return None
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
"""
Validate JWT payload has required fields
FIXED: Updated to match actual token structure from auth service
"""
required_fields = ["user_id", "email", "exp", "type"]
missing_fields = [field for field in required_fields if field not in payload]
if missing_fields:
logger.warning(f"Token payload missing fields: {missing_fields}")
return False
# Validate token type
token_type = payload.get("type")
if token_type not in ["access", "service"]:
logger.warning(f"Invalid token type: {payload.get('type')}")
return False
# Check if token is near expiry (within 5 minutes) and log warning
import time
exp_time = payload.get("exp", 0)
current_time = time.time()
time_until_expiry = exp_time - current_time
if time_until_expiry < 300: # 5 minutes
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
# NEW: Check token freshness for subscription changes
if payload.get("tenant_id"):
try:
# Note: We can't await here because this is a sync function
# Token freshness will be checked in the async dispatch method
# For now, just log that we would check freshness
logger.debug("Token freshness check would be performed in async context",
tenant_id=payload.get("tenant_id"))
except Exception as e:
logger.warning("Token freshness check setup failed", error=str(e))
# FIX: Validate service tokens with tenant context for tenant-scoped routes
if token_type == "service" and payload.get("tenant_id"):
# Service tokens with tenant context are valid for tenant-scoped operations
logger.debug("Service token with tenant context validated",
service=payload.get("service"), tenant_id=payload.get("tenant_id"))
return True
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
"""
Validate JWT payload integrity beyond signature verification.
Prevents edge cases where payload might be malformed.
"""
# Required fields must exist
required_fields = ["user_id", "email", "exp", "iat", "iss"]
if not all(field in payload for field in required_fields):
logger.warning("JWT missing required fields", missing=[f for f in required_fields if f not in payload])
return False
# Issuer must be our auth service
if payload.get("iss") != "bakery-auth":
logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
return False
# Token type must be valid
if payload.get("type") not in ["access", "service"]:
logger.warning("JWT has invalid type", token_type=payload.get("type"))
return False
# Subscription tier must be valid if present
valid_tiers = ["starter", "professional", "enterprise"]
if payload.get("subscription"):
tier = payload["subscription"].get("tier", "").lower()
if tier and tier not in valid_tiers:
logger.warning("JWT has invalid subscription tier", tier=tier)
return False
return True
async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
"""
Verify token was issued after the last subscription change.
Prevents use of stale tokens with old subscription data.
"""
if not self.redis_client:
return True # Skip check if no Redis
try:
subscription_changed_at = await self.redis_client.get(
f"tenant:{tenant_id}:subscription_changed_at"
)
if subscription_changed_at:
changed_timestamp = float(subscription_changed_at)
token_issued_at = payload.get("iat", 0)
if token_issued_at < changed_timestamp:
logger.warning(
"Token issued before subscription change",
token_iat=token_issued_at,
subscription_changed=changed_timestamp,
tenant_id=tenant_id
)
return False # Token is stale
except Exception as e:
logger.warning("Failed to check token freshness", error=str(e))
return True
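    # Producer-side expectation (assumption, not shown in this commit): whichever service
    # applies a subscription change should record the change time so the check above can
    # reject tokens issued earlier, e.g.:
    #   await redis_client.set(f"tenant:{tenant_id}:subscription_changed_at", str(time.time()))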
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert JWT payload to user context format
FIXED: Proper mapping between JWT structure and user context
"""
# NEW: Validate JWT integrity before processing
if not self._validate_jwt_integrity(payload):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid JWT payload"
)
base_context = {
"user_id": payload["user_id"],
"email": payload["email"],
"exp": payload["exp"],
"valid": True,
"role": payload.get("role", "user"),
}
# NEW: Extract subscription from JWT
if payload.get("tenant_id"):
base_context["tenant_id"] = payload["tenant_id"]
base_context["tenant_role"] = payload.get("tenant_role", "member")
if payload.get("subscription"):
sub = payload["subscription"]
base_context["subscription_tier"] = sub.get("tier", "starter")
base_context["subscription_status"] = sub.get("status", "active")
base_context["subscription_from_jwt"] = True # Flag to skip HTTP
if payload.get("tenant_access"):
base_context["tenant_access"] = payload["tenant_access"]
if payload.get("service"):
service_name = payload["service"]
base_context["service"] = service_name
base_context["type"] = "service"
base_context["role"] = "admin"
base_context["user_id"] = f"{service_name}-service"
base_context["email"] = f"{service_name}-service@internal"
# FIX: Service tokens with tenant context should use that tenant_id
if payload.get("tenant_id"):
base_context["tenant_id"] = payload["tenant_id"]
logger.debug(f"Service authentication with tenant context: {service_name}, tenant_id: {payload['tenant_id']}")
else:
logger.debug(f"Service authentication: {service_name}")
return base_context
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
"""
Verify token with auth service
FIXED: Improved error handling and response parsing
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
headers={"Authorization": f"Bearer {token}"}
)
if response.status_code == 200:
auth_response = response.json()
# Validate auth service response structure
if auth_response.get("valid") and auth_response.get("user_id"):
return {
"user_id": auth_response["user_id"],
"email": auth_response["email"],
"exp": auth_response.get("exp"),
"valid": True
}
else:
logger.warning(f"Auth service returned invalid response: {auth_response}")
return None
else:
logger.warning(f"Auth service returned {response.status_code}: {response.text}")
return None
except httpx.TimeoutException:
logger.error("Auth service timeout during token verification")
return None
except Exception as e:
logger.error(f"Auth service error: {e}")
return None
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
"""
Get user context from cache
FIXED: Better error handling and JSON parsing
"""
if not self.redis_client:
return None
cache_key = f"auth:token:{hash(token) % 1000000}" # Use modulo for shorter keys
try:
cached_data = await self.redis_client.get(cache_key)
if cached_data:
if isinstance(cached_data, bytes):
cached_data = cached_data.decode()
return json.loads(cached_data)
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse cached user data: {e}")
except Exception as e:
logger.warning(f"Cache lookup error: {e}")
return None
async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
"""
Cache user context
FIXED: Better error handling and expiration
"""
if not self.redis_client:
return
cache_key = f"auth:token:{hash(token) % 1000000}"
try:
# Cache for 5 minutes (shorter than token expiry)
await self.redis_client.setex(
cache_key,
300, # 5 minutes
json.dumps(user_context)
)
except Exception as e:
logger.warning(f"Failed to cache user context: {e}")
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
"""
Inject user and tenant context headers for downstream services using unified HeaderManager
"""
# Use unified HeaderManager for consistent header injection
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
# Add hierarchical access headers if tenant context exists
if tenant_id:
# If this is hierarchical access, include parent tenant ID
# Get parent tenant ID from the auth service if available
try:
import httpx
async with httpx.AsyncClient(timeout=3.0) as client:
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
headers={"Authorization": request.headers.get("Authorization", "")}
)
if response.status_code == 200:
hierarchy_data = response.json()
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
if parent_tenant_id:
# Add parent tenant ID using HeaderManager for consistency
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
header_value = str(parent_tenant_id)
header_manager.add_header_for_middleware(request, header_name, header_value)
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
except Exception as e:
logger.warning(f"Failed to get parent tenant ID: {e}")
pass
return injected_headers
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
"""
Get tenant subscription tier using fast cached endpoint
Args:
tenant_id: Tenant ID
request: FastAPI request for headers
Returns:
Subscription tier string or None
"""
try:
# Use fast cached subscription tier endpoint (has its own Redis caching)
async with httpx.AsyncClient(timeout=3.0) as client:
headers = {"Authorization": request.headers.get("Authorization", "")}
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
headers=headers
)
if response.status_code == 200:
tier_data = response.json()
subscription_tier = tier_data.get("tier", "starter")
logger.debug("Subscription tier from cached endpoint",
tenant_id=tenant_id,
tier=subscription_tier,
cached=tier_data.get("cached", False))
return subscription_tier
else:
logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
return "starter" # Default to starter
except Exception as e:
logger.error(f"Error getting tenant subscription tier: {e}")
return "starter" # Default to starter on error


@@ -0,0 +1,384 @@
"""
Demo Session Middleware
Handles demo account restrictions and virtual tenant injection
"""
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Optional
import uuid
import httpx
import structlog
import json
logger = structlog.get_logger()
# Fixed Demo Tenant IDs (these are the template tenants that will be cloned)
# Professional demo (merged from San Pablo + La Espiga)
DEMO_TENANT_PROFESSIONAL = uuid.UUID("a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6")
# Enterprise chain demo (parent + 3 children)
DEMO_TENANT_ENTERPRISE_CHAIN = uuid.UUID("c3d4e5f6-a7b8-49c0-d1e2-f3a4b5c6d7e8")
DEMO_TENANT_CHILD_1 = uuid.UUID("d4e5f6a7-b8c9-40d1-e2f3-a4b5c6d7e8f9")
DEMO_TENANT_CHILD_2 = uuid.UUID("e5f6a7b8-c9d0-41e2-f3a4-b5c6d7e8f9a0")
DEMO_TENANT_CHILD_3 = uuid.UUID("f6a7b8c9-d0e1-42f3-a4b5-c6d7e8f9a0b1")
# Demo tenant IDs (base templates)
DEMO_TENANT_IDS = {
str(DEMO_TENANT_PROFESSIONAL), # Professional demo tenant
str(DEMO_TENANT_ENTERPRISE_CHAIN), # Enterprise chain parent
str(DEMO_TENANT_CHILD_1), # Enterprise chain child 1
str(DEMO_TENANT_CHILD_2), # Enterprise chain child 2
str(DEMO_TENANT_CHILD_3), # Enterprise chain child 3
}
# Demo user IDs - Maps demo account type to actual user UUIDs from fixture files
# These IDs are the owner IDs from the respective 01-tenant.json files
DEMO_USER_IDS = {
"professional": "c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6", # María García López (professional/01-tenant.json -> owner.id)
"enterprise": "d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7" # Director (enterprise/parent/01-tenant.json -> owner.id)
}
# Allowed operations for demo accounts (limited write)
DEMO_ALLOWED_OPERATIONS = {
# Read operations - all allowed
"GET": ["*"],
# Limited write operations for realistic testing
"POST": [
"/api/v1/pos/sales",
"/api/v1/pos/sessions",
"/api/v1/orders",
"/api/v1/inventory/adjustments",
"/api/v1/sales",
"/api/v1/production/batches",
"/api/v1/tenants/batch/sales-summary",
"/api/v1/tenants/batch/production-summary",
"/api/v1/auth/me/onboarding/complete", # Allow completing onboarding (no-op for demos)
"/api/v1/tenants/*/notifications/send", # Allow notifications (ML insights, alerts, etc.)
# Note: Forecast generation is explicitly blocked (see DEMO_BLOCKED_PATHS)
],
"PUT": [
"/api/v1/pos/sales/*",
"/api/v1/orders/*",
"/api/v1/inventory/stock/*",
"/api/v1/auth/me/onboarding/step", # Allow onboarding step updates (no-op for demos)
],
# Blocked operations
"DELETE": [], # No deletes allowed
"PATCH": [], # No patches allowed
}
# Explicitly blocked paths for demo accounts (even if method would be allowed)
# These require trained AI models which demo tenants don't have
DEMO_BLOCKED_PATHS = [
"/api/v1/forecasts/single",
"/api/v1/forecasts/multi-day",
"/api/v1/forecasts/batch",
]
DEMO_BLOCKED_PATH_MESSAGE = {
"forecasts": {
"message": "La generación de pronósticos no está disponible para cuentas demo. "
"Las cuentas demo no tienen modelos de IA entrenados.",
"message_en": "Forecast generation is not available for demo accounts. "
"Demo accounts do not have trained AI models.",
}
}
class DemoMiddleware(BaseHTTPMiddleware):
"""Middleware to handle demo session logic with Redis caching"""
def __init__(self, app, demo_session_url: str = "http://demo-session-service:8000"):
super().__init__(app)
self.demo_session_url = demo_session_url
self._redis_client = None
async def _get_redis_client(self):
"""Get or lazily initialize Redis client"""
if self._redis_client is None:
try:
from shared.redis_utils import get_redis_client
self._redis_client = await get_redis_client()
logger.debug("Demo middleware: Redis client initialized")
except Exception as e:
logger.warning(f"Demo middleware: Failed to get Redis client: {e}. Caching disabled.")
self._redis_client = False # Sentinel value to avoid retrying
return self._redis_client if self._redis_client is not False else None
async def _get_cached_session(self, session_id: str) -> Optional[dict]:
"""Get session info from Redis cache"""
try:
redis_client = await self._get_redis_client()
if not redis_client:
return None
cache_key = f"demo_session:{session_id}"
cached_data = await redis_client.get(cache_key)
if cached_data:
logger.debug("Demo middleware: Cache HIT", session_id=session_id)
return json.loads(cached_data)
else:
logger.debug("Demo middleware: Cache MISS", session_id=session_id)
return None
except Exception as e:
logger.warning(f"Demo middleware: Redis cache read error: {e}")
return None
async def _cache_session(self, session_id: str, session_info: dict, ttl: int = 30):
"""Cache session info in Redis with TTL"""
try:
redis_client = await self._get_redis_client()
if not redis_client:
return
cache_key = f"demo_session:{session_id}"
serialized = json.dumps(session_info)
await redis_client.setex(cache_key, ttl, serialized)
logger.debug(f"Demo middleware: Cached session {session_id} (TTL: {ttl}s)")
except Exception as e:
logger.warning(f"Demo middleware: Redis cache write error: {e}")
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request through demo middleware"""
# Skip demo middleware for demo service endpoints
demo_service_paths = [
"/api/v1/demo/accounts",
"/api/v1/demo/sessions",
"/api/v1/demo/stats",
"/api/v1/demo/operations",
]
if any(request.url.path.startswith(path) or request.url.path == path for path in demo_service_paths):
return await call_next(request)
# Extract session ID from header or cookie
session_id = (
request.headers.get("X-Demo-Session-Id") or
request.cookies.get("demo_session_id")
)
logger.info(f"🎭 DemoMiddleware - path: {request.url.path}, session_id: {session_id}")
# Extract tenant ID from request
tenant_id = request.headers.get("X-Tenant-Id")
# Check if this is a demo session request
if session_id:
try:
# PERFORMANCE OPTIMIZATION: Check Redis cache first before HTTP call
session_info = await self._get_cached_session(session_id)
if not session_info:
# Cache miss - fetch from demo service
logger.debug("Demo middleware: Fetching from demo service", session_id=session_id)
session_info = await self._get_session_info(session_id)
# Cache the result if successful (30s TTL to balance freshness vs performance)
if session_info:
await self._cache_session(session_id, session_info, ttl=30)
# Accept pending, ready, partial, failed (if data exists), and active (deprecated) statuses
# Even "failed" sessions can be usable if some services succeeded
valid_statuses = ["pending", "ready", "partial", "failed", "active"]
current_status = session_info.get("status") if session_info else None
if session_info and current_status in valid_statuses:
# NOTE: Path transformation for demo-user removed.
# Frontend now receives the real demo_user_id from session creation
# and uses it directly in API calls.
# Inject virtual tenant ID
# Use scope state directly to avoid potential state property issues
request.scope.setdefault("state", {})
state = request.scope["state"]
state["tenant_id"] = session_info["virtual_tenant_id"]
state["is_demo_session"] = True
state["demo_account_type"] = session_info["demo_account_type"]
state["demo_session_status"] = current_status # Track status for monitoring
# Inject demo user context for auth middleware
# Uses DEMO_USER_IDS constant defined at module level
demo_user_id = DEMO_USER_IDS.get(
session_info.get("demo_account_type", "professional"),
DEMO_USER_IDS["professional"]
)
# This allows the request to pass through AuthMiddleware
# NEW: Extract subscription tier from demo account type
subscription_tier = "enterprise" if session_info.get("demo_account_type") == "enterprise" else "professional"
state["user"] = {
"user_id": demo_user_id, # Use actual demo user UUID
"email": f"demo-{session_id}@demo.local",
"tenant_id": session_info["virtual_tenant_id"],
"role": "owner", # Demo users have owner role
"is_demo": True,
"demo_session_id": session_id,
"demo_account_type": session_info.get("demo_account_type", "professional"),
"demo_session_status": current_status,
# NEW: Subscription context (no HTTP call needed!)
"subscription_tier": subscription_tier,
"subscription_status": "active",
"subscription_from_jwt": True # Flag to skip HTTP calls
}
# Update activity
await self._update_session_activity(session_id)
# Check if path is explicitly blocked
blocked_reason = self._check_blocked_path(request.url.path)
if blocked_reason:
return JSONResponse(
status_code=403,
content={
"error": "demo_restriction",
**blocked_reason,
"upgrade_url": "/pricing",
"session_expires_at": session_info.get("expires_at")
}
)
# Check if operation is allowed
if not self._is_operation_allowed(request.method, request.url.path):
return JSONResponse(
status_code=403,
content={
"error": "demo_restriction",
"message": "Esta operación no está permitida en cuentas demo. "
"Las sesiones demo se eliminan automáticamente después de 30 minutos. "
"Suscríbete para obtener acceso completo.",
"message_en": "This operation is not allowed in demo accounts. "
"Demo sessions are automatically deleted after 30 minutes. "
"Subscribe for full access.",
"upgrade_url": "/pricing",
"session_expires_at": session_info.get("expires_at")
}
)
else:
# Session expired, invalid, or in failed/destroyed state
logger.warning(f"Invalid demo session state", session_id=session_id, status=current_status)
return JSONResponse(
status_code=401,
content={
"error": "session_expired",
"message": "Tu sesión demo ha expirado. Crea una nueva sesión para continuar.",
"message_en": "Your demo session has expired. Create a new session to continue.",
"session_status": current_status
}
)
except Exception as e:
logger.error("Demo middleware error", error=str(e), session_id=session_id, path=request.url.path)
# On error, return 401 instead of continuing
return JSONResponse(
status_code=401,
content={
"error": "session_error",
"message": "Error validando sesión demo. Por favor, inténtalo de nuevo.",
"message_en": "Error validating demo session. Please try again."
}
)
# Check if this is a demo tenant (base template)
elif tenant_id in DEMO_TENANT_IDS:
# Direct access to demo tenant without session - block writes
request.scope.setdefault("state", {})
state = request.scope["state"]
state["is_demo_session"] = True
state["tenant_id"] = tenant_id
if request.method not in ["GET", "HEAD", "OPTIONS"]:
return JSONResponse(
status_code=403,
content={
"error": "demo_restriction",
"message": "Acceso directo al tenant demo no permitido. Crea una sesión demo.",
"message_en": "Direct access to demo tenant not allowed. Create a demo session."
}
)
# Proceed with request
response = await call_next(request)
# Add demo session header to response if demo session
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
response.headers["X-Demo-Session"] = "true"
return response
async def _get_session_info(self, session_id: str) -> Optional[dict]:
"""Get session information from demo service using JWT service token"""
try:
# Create JWT service token for gateway-to-demo-session communication
from shared.auth.jwt_handler import JWTHandler
from app.core.config import settings
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
service_token = jwt_handler.create_service_token(service_name="gateway")
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}",
headers={"Authorization": f"Bearer {service_token}"}
)
if response.status_code == 200:
return response.json()
else:
logger.warning("Demo session fetch failed",
session_id=session_id,
status_code=response.status_code,
response_text=response.text[:200] if hasattr(response, 'text') else '')
return None
except Exception as e:
logger.error("Failed to get session info", session_id=session_id, error=str(e))
return None
async def _update_session_activity(self, session_id: str):
"""Update session activity timestamp"""
# Note: Activity tracking is handled by the demo service internally
# No explicit endpoint needed - activity is updated on session access
pass
def _check_blocked_path(self, path: str) -> Optional[dict]:
"""Check if path is explicitly blocked for demo accounts"""
for blocked_path in DEMO_BLOCKED_PATHS:
if blocked_path in path:
# Determine which category of blocked path
if "forecast" in blocked_path:
return DEMO_BLOCKED_PATH_MESSAGE["forecasts"]
# Can add more categories here in the future
return {
"message": "Esta funcionalidad no está disponible para cuentas demo.",
"message_en": "This functionality is not available for demo accounts."
}
return None
    def _is_operation_allowed(self, method: str, path: str) -> bool:
        """Check if method + path combination is allowed for demo"""
        allowed_paths = DEMO_ALLOWED_OPERATIONS.get(method, [])
        # Full wildcard: every path is allowed for this method (e.g. GET)
        if "*" in allowed_paths:
            return True
        # Check for exact match or pattern match
        for allowed_path in allowed_paths:
            if "*" in allowed_path:
                # Pattern match supports trailing and mid-path wildcards, e.g.
                # /api/v1/orders/* matches /api/v1/orders/123 and
                # /api/v1/tenants/*/notifications/send matches any tenant ID
                if fnmatch.fnmatchcase(path, allowed_path):
                    return True
            elif path == allowed_path:
                # Exact match
                return True
        return False
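
# Illustrative sketch (not part of the original commit): how the demo restriction
# rules above behave for a few representative requests. The helpers below do not
# touch `app`, so passing None is enough; paths are hypothetical examples and the
# mid-path wildcard case assumes the fnmatch-based matching in _is_operation_allowed.
if __name__ == "__main__":  # pragma: no cover - documentation example only
    _demo = DemoMiddleware(app=None)
    assert _demo._is_operation_allowed("GET", "/api/v1/tenants/abc/orders")              # "*" wildcard for GET
    assert _demo._is_operation_allowed("POST", "/api/v1/pos/sales")                      # exact match
    assert _demo._is_operation_allowed("PUT", "/api/v1/orders/123")                      # trailing wildcard
    assert _demo._is_operation_allowed("POST", "/api/v1/tenants/42/notifications/send")  # mid-path wildcard
    assert not _demo._is_operation_allowed("DELETE", "/api/v1/orders/123")               # deletes blocked for demos
    assert _demo._check_blocked_path("/api/v1/forecasts/single") is not None             # forecasts need trained models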

View File

@@ -0,0 +1,57 @@
"""
Logging middleware for gateway
"""
import logging
import time
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import uuid
logger = logging.getLogger(__name__)
class LoggingMiddleware(BaseHTTPMiddleware):
"""Logging middleware class"""
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with logging"""
start_time = time.time()
# Generate request ID
request_id = str(uuid.uuid4())
request.state.request_id = request_id
# Log request
logger.info(
f"Request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": request.url.path,
"query_params": str(request.query_params),
"client_host": request.client.host if request.client else "unknown",
"user_agent": request.headers.get("user-agent", ""),
"request_id": request_id
}
)
# Process request
response = await call_next(request)
# Calculate duration
duration = time.time() - start_time
# Log response
logger.info(
f"Response: {response.status_code} in {duration:.3f}s",
extra={
"status_code": response.status_code,
"duration": duration,
"method": request.method,
"url": request.url.path,
"request_id": request_id
}
)
return response

View File

@@ -0,0 +1,93 @@
"""
Rate limiting middleware for gateway
"""
import logging
import time
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from typing import Dict, Optional
import asyncio
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware class"""
def __init__(self, app, calls_per_minute: int = 60):
super().__init__(app)
self.calls_per_minute = calls_per_minute
self.requests: Dict[str, list] = {}
self._cleanup_task = None
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with rate limiting"""
# Skip rate limiting for health checks
if request.url.path in ["/health", "/metrics"]:
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check rate limit
if self._is_rate_limited(client_id):
logger.warning(f"Rate limit exceeded for client: {client_id}")
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"}
)
# Record request
self._record_request(client_id)
# Process request
return await call_next(request)
def _get_client_id(self, request: Request) -> str:
"""Get client identifier"""
# Try to get user ID from state (if authenticated)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user.get('user_id', 'unknown')}"
# Fall back to IP address
return f"ip:{request.client.host if request.client else 'unknown'}"
def _is_rate_limited(self, client_id: str) -> bool:
"""Check if client is rate limited"""
now = time.time()
minute_ago = now - 60
# Get recent requests for this client
if client_id not in self.requests:
return False
# Filter requests from last minute
recent_requests = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]
# Update the list
self.requests[client_id] = recent_requests
# Check if limit exceeded
return len(recent_requests) >= self.calls_per_minute
def _record_request(self, client_id: str):
"""Record a request for rate limiting"""
now = time.time()
if client_id not in self.requests:
self.requests[client_id] = []
self.requests[client_id].append(now)
# Keep only last minute of requests
minute_ago = now - 60
self.requests[client_id] = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]

View File

@@ -0,0 +1,269 @@
"""
API Rate Limiting Middleware for Gateway
Enforces subscription-based API call quotas per hour
"""
import structlog
import shared.redis_utils
from datetime import datetime, timezone
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Optional
logger = structlog.get_logger()
class APIRateLimitMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce API rate limits based on subscription tier.
Quota limits per hour:
- Starter: 100 calls/hour
- Professional: 1,000 calls/hour
- Enterprise: 10,000 calls/hour
Uses Redis to track API calls with hourly buckets.
"""
def __init__(self, app, redis_client=None):
super().__init__(app)
self.redis_client = redis_client
async def dispatch(self, request: Request, call_next):
"""
Check API rate limit before processing request.
"""
# Skip rate limiting for certain paths
if self._should_skip_rate_limit(request.url.path):
return await call_next(request)
# Extract tenant_id from request
tenant_id = self._extract_tenant_id(request)
if not tenant_id:
# No tenant ID - skip rate limiting for auth/public endpoints
return await call_next(request)
try:
# Get subscription tier from headers (added by AuthMiddleware)
subscription_tier = request.headers.get("x-subscription-tier")
if not subscription_tier:
# Fallback: get from request state if headers not available
subscription_tier = getattr(request.state, "subscription_tier", None)
if not subscription_tier:
# Final fallback: get from tenant service (should rarely happen)
subscription_tier = await self._get_subscription_tier(tenant_id, request)
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
# Get quota limit for tier
quota_limit = self._get_quota_limit(subscription_tier)
# Check and increment quota
allowed, current_count = await self._check_and_increment_quota(
tenant_id,
quota_limit
)
if not allowed:
logger.warning(
"API rate limit exceeded",
tenant_id=tenant_id,
subscription_tier=subscription_tier,
current_count=current_count,
quota_limit=quota_limit,
path=request.url.path
)
                # Return the 429 directly: exceptions raised inside BaseHTTPMiddleware
                # bypass the app's exception handlers, so a JSONResponse is returned
                # here (matching the other gateway middlewares) instead of raising.
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={
                        "detail": {
                            "error": "rate_limit_exceeded",
                            "message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
                            "current_count": current_count,
                            "quota_limit": quota_limit,
                            "reset_time": self._get_reset_time(),
                            "upgrade_required": subscription_tier in ['starter', 'professional']
                        }
                    }
                )
# Add rate limit headers to response
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(quota_limit)
response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
response.headers["X-RateLimit-Reset"] = self._get_reset_time()
return response
except HTTPException:
raise
except Exception as e:
logger.error(
"Rate limiting check failed, allowing request",
tenant_id=tenant_id,
error=str(e),
path=request.url.path
)
# Fail open - allow request if rate limiting fails
return await call_next(request)
def _should_skip_rate_limit(self, path: str) -> bool:
"""
Determine if path should skip rate limiting.
"""
skip_paths = [
"/health",
"/metrics",
"/docs",
"/openapi.json",
"/api/v1/auth/",
"/api/v1/plans", # Public pricing info
]
for skip_path in skip_paths:
if path.startswith(skip_path):
return True
return False
def _extract_tenant_id(self, request: Request) -> Optional[str]:
"""
Extract tenant ID from request headers or path.
"""
# Try header first
tenant_id = request.headers.get("x-tenant-id")
if tenant_id:
return tenant_id
# Try to extract from path /api/v1/tenants/{tenant_id}/...
path_parts = request.url.path.split("/")
if "tenants" in path_parts:
try:
tenant_index = path_parts.index("tenants")
if len(path_parts) > tenant_index + 1:
return path_parts[tenant_index + 1]
except (ValueError, IndexError):
pass
return None
async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
"""
Get subscription tier from tenant service (with caching).
"""
try:
# Try to get from request state (if subscription middleware already ran)
if hasattr(request.state, "subscription_tier"):
return request.state.subscription_tier
# Call tenant service to get tier
import httpx
            from app.core.config import settings
async with httpx.AsyncClient(timeout=2.0) as client:
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
headers={
"x-service": "gateway"
}
)
if response.status_code == 200:
data = response.json()
return data.get("tier", "starter")
except Exception as e:
logger.warning(
"Failed to get subscription tier, defaulting to starter",
tenant_id=tenant_id,
error=str(e)
)
return "starter"
def _get_quota_limit(self, subscription_tier: str) -> int:
"""
Get API calls per hour quota for subscription tier.
"""
quota_map = {
"starter": 100,
"professional": 1000,
"enterprise": 10000,
"demo": 1000, # Same as professional
}
return quota_map.get(subscription_tier.lower(), 100)
async def _check_and_increment_quota(
self,
tenant_id: str,
quota_limit: int
) -> tuple[bool, int]:
"""
Check current quota usage and increment counter.
Returns:
(allowed: bool, current_count: int)
"""
if not self.redis_client:
# No Redis - fail open
return True, 0
try:
# Create hourly bucket key
current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"
# Get current count
current_count = await self.redis_client.get(quota_key)
current_count = int(current_count) if current_count else 0
# Check if within limit
if current_count >= quota_limit:
return False, current_count
# Increment counter
new_count = await self.redis_client.incr(quota_key)
# Set expiry (1 hour + 5 minutes buffer)
await self.redis_client.expire(quota_key, 3900)
return True, new_count
except Exception as e:
logger.error(
"Redis quota check failed",
tenant_id=tenant_id,
error=str(e)
)
# Fail open
return True, 0
def _get_reset_time(self) -> str:
"""
Get the reset time for the current hour bucket (top of next hour).
"""
from datetime import timedelta
now = datetime.now(timezone.utc)
next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
return next_hour.isoformat()
async def get_rate_limit_middleware(app):
"""
Factory function to create rate limiting middleware with Redis client.
"""
try:
        from app.core.config import settings
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
logger.info("API rate limiting middleware initialized with Redis")
return APIRateLimitMiddleware(app, redis_client=redis_client)
except Exception as e:
logger.warning(
"Failed to initialize Redis for rate limiting, middleware will fail open",
error=str(e)
)
return APIRateLimitMiddleware(app, redis_client=None)
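
# Illustrative sketch (not part of the original commit): the hourly-bucket scheme used
# above, shown standalone. A key like "quota:hourly:api_calls:<tenant>:2026-01-21-14"
# is incremented once per request and expires shortly after the hour rolls over, so
# quotas reset at the top of each hour without any cleanup job. The tenant ID below
# is a hypothetical example.
if __name__ == "__main__":  # pragma: no cover - documentation example only
    from datetime import timedelta

    example_tenant_id = "tenant-123"
    now = datetime.now(timezone.utc)
    bucket = now.strftime("%Y-%m-%d-%H")
    quota_key = f"quota:hourly:api_calls:{example_tenant_id}:{bucket}"
    reset_at = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
    print(quota_key)              # e.g. quota:hourly:api_calls:tenant-123:2026-01-21-14
    print(reset_at.isoformat())   # the value exposed in the X-RateLimit-Reset header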

View File

@@ -0,0 +1,149 @@
"""
Gateway middleware to enforce read-only mode for subscriptions with status:
- pending_cancellation (until cancellation_effective_date)
- inactive (after cancellation or no active subscription)
Allowed operations in read-only mode:
- GET requests (all read operations)
- POST /api/v1/users/me/delete/request (account deletion)
- POST /api/v1/subscriptions/reactivate (subscription reactivation)
- POST /api/v1/subscriptions/* (subscription management)
"""
import httpx
import logging
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Optional
import re
logger = logging.getLogger(__name__)
# Whitelist of POST/PUT/DELETE endpoints allowed in read-only mode
READ_ONLY_WHITELIST_PATTERNS = [
r'^/api/v1/users/me/delete/request$',
r'^/api/v1/users/me/export.*$',
r'^/api/v1/tenants/.*/subscription/.*', # All tenant subscription endpoints
r'^/api/v1/registration/.*', # Registration flow endpoints
r'^/api/v1/auth/.*', # Allow auth operations
r'^/api/v1/tenants/register$', # Allow new tenant registration (no existing tenant context)
r'^/api/v1/tenants/.*/orchestrator/run-daily-workflow$', # Allow workflow testing
r'^/api/v1/tenants/.*/inventory/ml/insights/.*', # Allow ML insights (safety stock optimization)
r'^/api/v1/tenants/.*/production/ml/insights/.*', # Allow ML insights (yield prediction)
r'^/api/v1/tenants/.*/procurement/ml/insights/.*', # Allow ML insights (supplier analysis, price forecasting)
r'^/api/v1/tenants/.*/forecasting/ml/insights/.*', # Allow ML insights (rules generation)
r'^/api/v1/tenants/.*/forecasting/operations/.*', # Allow forecasting operations
r'^/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
]
class ReadOnlyModeMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce read-only mode based on subscription status
"""
def __init__(self, app, tenant_service_url: str = "http://tenant-service:8000"):
super().__init__(app)
self.tenant_service_url = tenant_service_url
self.cache = {}
self.cache_ttl = 60
async def check_subscription_status(self, tenant_id: str, authorization: str) -> dict:
"""
Check subscription status from tenant service
Returns subscription data including status and read_only flag
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.tenant_service_url}/api/v1/tenants/{tenant_id}/subscription/status",
headers={"Authorization": authorization}
)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
return {"status": "inactive", "is_read_only": True}
else:
logger.warning(
f"Failed to check subscription status: {response.status_code}",
extra={"tenant_id": tenant_id}
)
return {"status": "unknown", "is_read_only": False}
except Exception as e:
logger.error(
f"Error checking subscription status: {e}",
extra={"tenant_id": tenant_id}
)
return {"status": "unknown", "is_read_only": False}
def is_whitelisted_endpoint(self, path: str) -> bool:
"""
Check if endpoint is whitelisted for read-only mode
"""
for pattern in READ_ONLY_WHITELIST_PATTERNS:
if re.match(pattern, path):
return True
return False
def is_write_operation(self, method: str) -> bool:
"""
Determine if HTTP method is a write operation
"""
return method.upper() in ['POST', 'PUT', 'DELETE', 'PATCH']
async def dispatch(self, request: Request, call_next):
"""
Process each request through read-only mode check
"""
tenant_id = request.headers.get("X-Tenant-ID")
authorization = request.headers.get("Authorization")
path = request.url.path
method = request.method
if not tenant_id or not authorization:
return await call_next(request)
if method.upper() == 'GET':
return await call_next(request)
if self.is_whitelisted_endpoint(path):
return await call_next(request)
if self.is_write_operation(method):
subscription_data = await self.check_subscription_status(tenant_id, authorization)
if subscription_data.get("is_read_only", False):
status_detail = subscription_data.get("status", "inactive")
effective_date = subscription_data.get("cancellation_effective_date")
error_message = {
"detail": "Account is in read-only mode",
"reason": f"Subscription status: {status_detail}",
"message": "Your subscription has been cancelled. You can view data but cannot make changes.",
"action_required": "Reactivate your subscription to regain full access",
"reactivation_url": "/app/settings/subscription"
}
if effective_date:
error_message["read_only_until"] = effective_date
error_message["message"] = f"Your subscription is pending cancellation. Read-only mode starts on {effective_date}."
logger.info(
"read_only_mode_enforced",
extra={
"tenant_id": tenant_id,
"path": path,
"method": method,
"subscription_status": status_detail
}
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content=error_message
)
return await call_next(request)
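
# Illustrative sketch (not part of the original commit): which write requests the
# whitelist above lets through while an account is in read-only mode. Paths are
# hypothetical examples matched against READ_ONLY_WHITELIST_PATTERNS.
if __name__ == "__main__":  # pragma: no cover - documentation example only
    _mw = ReadOnlyModeMiddleware(app=None)
    assert _mw.is_whitelisted_endpoint("/api/v1/users/me/delete/request")             # account deletion
    assert _mw.is_whitelisted_endpoint("/api/v1/tenants/42/subscription/reactivate")  # subscription management
    assert _mw.is_whitelisted_endpoint("/api/v1/webhooks/stripe")                     # webhooks (no tenant context)
    assert not _mw.is_whitelisted_endpoint("/api/v1/tenants/42/orders")               # normal writes stay blocked
    assert _mw.is_write_operation("PATCH") and not _mw.is_write_operation("GET")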

View File

@@ -0,0 +1,83 @@
"""
Request ID Middleware for distributed tracing
Generates and propagates unique request IDs across all services
"""
import uuid
import structlog
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from app.core.header_manager import header_manager
logger = structlog.get_logger()
class RequestIDMiddleware(BaseHTTPMiddleware):
"""
Middleware to generate and propagate request IDs for distributed tracing.
Request IDs are:
- Generated if not provided by client
- Logged with every request
- Propagated to all downstream services
- Returned in response headers
"""
def __init__(self, app):
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with request ID tracking"""
# Extract or generate request ID
request_id = request.headers.get("X-Request-ID")
if not request_id:
request_id = str(uuid.uuid4())
# Store in request state for access by routes
request.state.request_id = request_id
# Bind request ID to structured logger context
logger_ctx = logger.bind(request_id=request_id)
# Inject request ID header for downstream services using HeaderManager
# Note: This runs early in middleware chain, so we use add_header_for_middleware
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
# Log request start
logger_ctx.info(
"Request started",
method=request.method,
path=request.url.path,
client_ip=request.client.host if request.client else None
)
try:
# Process request
response = await call_next(request)
# Add request ID to response headers
response.headers["X-Request-ID"] = request_id
# Log request completion
logger_ctx.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code
)
return response
except Exception as e:
# Log request failure
logger_ctx.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
error_type=type(e).__name__
)
raise
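
# Illustrative sketch (not part of the original commit): assumed registration order.
# Starlette applies add_middleware() in reverse order (last added runs first), so
# adding RequestIDMiddleware last makes it the outermost layer and lets every later
# middleware and route handler see request.state.request_id and the injected
# x-request-id header.
if __name__ == "__main__":  # pragma: no cover - documentation example only
    from fastapi import FastAPI

    example_app = FastAPI()
    # ... inner middlewares (auth, subscription, rate limiting) would be added here ...
    example_app.add_middleware(RequestIDMiddleware)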

View File

@@ -0,0 +1,462 @@
"""
Subscription Middleware - Enforces subscription limits and feature access
Updated to support standardized URL structure with tier-based access control
"""
import re
import json
import structlog
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import httpx
from typing import Dict, Any, Optional, List
import asyncio
from datetime import datetime, timezone
from app.core.config import settings
from app.core.header_manager import header_manager
from app.utils.subscription_error_responses import create_upgrade_required_response
logger = structlog.get_logger()
class SubscriptionMiddleware(BaseHTTPMiddleware):
"""
Middleware to enforce subscription-based access control
Supports standardized URL structure:
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
"""
def __init__(self, app, tenant_service_url: str, redis_client=None):
super().__init__(app)
self.tenant_service_url = tenant_service_url.rstrip('/')
self.redis_client = redis_client # Optional Redis client for abuse detection
# Define route patterns that require subscription validation
# Using new standardized URL structure
self.protected_routes = {
# ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
# Any service analytics endpoint
r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
'feature': 'analytics',
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Analytics features (Professional/Enterprise only)'
},
# ===== TRAINING SERVICE - ALL TIERS =====
r'^/api/v1/tenants/[^/]+/training/.*': {
'feature': 'ml_training',
'minimum_tier': 'basic',
'allowed_tiers': ['basic', 'professional', 'enterprise'],
'description': 'Machine learning model training (Available for all tiers)'
},
# ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
# Advanced reporting and exports
r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
'feature': 'advanced_exports',
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Advanced export formats (Professional/Enterprise only)'
},
# Bulk operations
r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
'feature': 'bulk_operations',
'minimum_tier': 'professional',
'allowed_tiers': ['professional', 'enterprise'],
'description': 'Bulk operations (Professional/Enterprise only)'
},
}
# Routes that are explicitly allowed for all tiers (no check needed)
self.public_tier_routes = [
# Base CRUD operations - ALL TIERS
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
# Dashboard routes - ALL TIERS
r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
# Operations routes - ALL TIERS (role-based control applies)
r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
]
async def dispatch(self, request: Request, call_next):
"""Process the request and check subscription requirements"""
# Skip subscription check for certain routes
if self._should_skip_subscription_check(request):
return await call_next(request)
# Check if route is explicitly allowed for all tiers
if self._is_public_tier_route(request.url.path):
return await call_next(request)
# Check if route requires subscription validation
subscription_requirement = self._get_subscription_requirement(request.url.path)
if not subscription_requirement:
return await call_next(request)
# Get tenant ID from request
tenant_id = self._extract_tenant_id(request)
if not tenant_id:
return JSONResponse(
status_code=400,
content={
"error": "subscription_validation_failed",
"message": "Tenant ID required for subscription validation",
"code": "MISSING_TENANT_ID"
}
)
# Validate subscription with new tier-based system
validation_result = await self._validate_subscription_tier(
request,
tenant_id,
subscription_requirement.get('feature'),
subscription_requirement.get('minimum_tier'),
subscription_requirement.get('allowed_tiers', [])
)
if not validation_result['allowed']:
# Use enhanced error response with conversion optimization
feature = subscription_requirement.get('feature')
current_tier = validation_result.get('current_tier', 'unknown')
required_tier = subscription_requirement.get('minimum_tier')
allowed_tiers = subscription_requirement.get('allowed_tiers', [])
# Create conversion-optimized error response
enhanced_response = create_upgrade_required_response(
feature=feature,
current_tier=current_tier,
required_tier=required_tier,
allowed_tiers=allowed_tiers,
custom_message=validation_result.get('message')
)
return JSONResponse(
status_code=enhanced_response.status_code,
content=enhanced_response.dict()
)
# Subscription validation passed, continue with request
response = await call_next(request)
return response
def _is_public_tier_route(self, path: str) -> bool:
"""
Check if route is explicitly allowed for all subscription tiers
Args:
path: Request path
Returns:
True if route is allowed for all tiers
"""
for pattern in self.public_tier_routes:
if re.match(pattern, path):
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
return True
return False
def _should_skip_subscription_check(self, request: Request) -> bool:
"""Check if subscription validation should be skipped"""
path = request.url.path
method = request.method
# Skip for health checks, auth, and public routes
skip_patterns = [
r'/health.*',
r'/metrics.*',
r'/api/v1/auth/.*',
r'/api/v1/tenants/[^/]+/subscription/.*', # All tenant subscription endpoints
r'/api/v1/registration/.*', # Registration flow endpoints
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
r'/docs.*',
r'/openapi\.json',
# Training monitoring endpoints (WebSocket and status checks)
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
]
# Skip OPTIONS requests (CORS preflight)
if method == "OPTIONS":
return True
for pattern in skip_patterns:
if re.match(pattern, path):
return True
return False
def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
"""Get subscription requirement for a given path"""
for pattern, requirement in self.protected_routes.items():
if re.match(pattern, path):
return requirement
return None
def _extract_tenant_id(self, request: Request) -> Optional[str]:
"""Extract tenant ID from request"""
# Try to get from URL path first
path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
if path_match:
return path_match.group(1)
# Try to get from headers
tenant_id = request.headers.get('x-tenant-id')
if tenant_id:
return tenant_id
# Try to get from user state (set by auth middleware)
if hasattr(request.state, 'user') and request.state.user:
return request.state.user.get('tenant_id')
return None
async def _validate_subscription_tier(
self,
request: Request,
tenant_id: str,
feature: Optional[str],
minimum_tier: str,
allowed_tiers: List[str]
) -> Dict[str, Any]:
"""
Validate subscription tier access using cached subscription lookup
Args:
request: FastAPI request
tenant_id: Tenant ID
feature: Feature name (optional, for additional checks)
minimum_tier: Minimum required subscription tier
allowed_tiers: List of allowed subscription tiers
Returns:
Dict with 'allowed' boolean and additional metadata
"""
try:
# Check if JWT already has subscription
if hasattr(request.state, 'user') and request.state.user:
user_context = request.state.user
user_id = user_context.get('user_id', 'unknown')
if user_context.get("subscription_from_jwt"):
# Use JWT data directly - NO HTTP CALL!
current_tier = user_context.get("subscription_tier", "starter")
logger.debug("Using subscription tier from JWT (no HTTP call)",
tenant_id=tenant_id,
current_tier=current_tier,
minimum_tier=minimum_tier,
allowed_tiers=allowed_tiers)
if current_tier not in [tier.lower() for tier in allowed_tiers]:
tier_names = ', '.join(allowed_tiers)
return {
'allowed': False,
'message': f'This feature requires a {tier_names} subscription plan',
'current_tier': current_tier
}
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
True,
"jwt"
)
return {
'allowed': True,
'message': 'Access granted (JWT subscription)',
'current_tier': current_tier
}
# Use unified HeaderManager for consistent header handling
headers = header_manager.get_all_headers_for_proxy(request)
# Extract user_id for logging (fallback path)
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
timeout_config = httpx.Timeout(
connect=1.0, # Connection timeout - very short for cached endpoint
read=5.0, # Read timeout - short for cached lookup
write=1.0, # Write timeout
pool=1.0 # Pool timeout
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
# Use fast cached tier endpoint (new URL pattern)
tier_response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
headers=headers
)
if tier_response.status_code != 200:
logger.warning(
"Failed to get subscription tier from cache",
tenant_id=tenant_id,
status_code=tier_response.status_code,
response_text=tier_response.text
)
# Fail open for availability
return {
'allowed': True,
'message': 'Access granted (validation service unavailable)',
'current_tier': 'unknown'
}
tier_data = tier_response.json()
current_tier = tier_data.get('tier', 'starter').lower()
logger.debug("Subscription tier check (cached)",
tenant_id=tenant_id,
current_tier=current_tier,
minimum_tier=minimum_tier,
allowed_tiers=allowed_tiers,
cached=tier_data.get('cached', False))
# Check if current tier is in allowed tiers
if current_tier not in [tier.lower() for tier in allowed_tiers]:
tier_names = ', '.join(allowed_tiers)
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
False,
"jwt"
)
return {
'allowed': False,
'message': f'This feature requires a {tier_names} subscription plan',
'current_tier': current_tier
}
# Tier check passed
await self._log_subscription_access(
tenant_id,
user_id,
feature,
current_tier,
True,
"database"
)
return {
'allowed': True,
'message': 'Access granted',
'current_tier': current_tier
}
except asyncio.TimeoutError:
logger.error(
"Timeout validating subscription",
tenant_id=tenant_id,
feature=feature
)
# Fail open for availability (let service handle detailed check)
return {
'allowed': True,
'message': 'Access granted (validation timeout)',
                'current_tier': 'unknown'
}
except httpx.RequestError as e:
logger.error(
"Request error validating subscription",
tenant_id=tenant_id,
feature=feature,
error=str(e)
)
# Fail open for availability
return {
'allowed': True,
'message': 'Access granted (validation service unavailable)',
                'current_tier': 'unknown'
}
except Exception as e:
logger.error(
"Subscription validation error",
tenant_id=tenant_id,
feature=feature,
error=str(e)
)
# Fail open for availability (let service handle detailed check)
return {
'allowed': True,
'message': 'Access granted (validation error)',
                'current_tier': 'unknown'
}
async def _log_subscription_access(
self,
tenant_id: str,
user_id: str,
requested_feature: str,
current_tier: str,
access_granted: bool,
source: str # "jwt" or "database"
):
"""
Log all subscription-gated access attempts for audit and anomaly detection.
"""
logger.info(
"Subscription access check",
tenant_id=tenant_id,
user_id=user_id,
feature=requested_feature,
tier=current_tier,
granted=access_granted,
source=source,
timestamp=datetime.now(timezone.utc).isoformat()
)
# For denied access, check for suspicious patterns
if not access_granted:
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
async def _check_for_abuse_patterns(
self,
tenant_id: str,
user_id: str,
feature: str
):
"""
Detect potential abuse patterns like repeated premium feature access attempts.
"""
if not self.redis_client:
return
# Track denied attempts in a sliding window
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
try:
attempts = await self.redis_client.incr(key)
if attempts == 1:
await self.redis_client.expire(key, 3600) # 1 hour window
# Alert if too many denied attempts (potential bypass attempt)
if attempts > 10:
logger.warning(
"SECURITY: Excessive premium feature access attempts detected",
tenant_id=tenant_id,
user_id=user_id,
feature=feature,
attempts=attempts,
window="1 hour"
)
# Could trigger alert to security team here
except Exception as e:
logger.warning("Failed to track abuse patterns", error=str(e))

View File

240
gateway/app/routes/auth.py Normal file
View File

@@ -0,0 +1,240 @@
# ================================================================
# gateway/app/routes/auth.py
# ================================================================
"""
Authentication and User Management Routes for API Gateway
Unified proxy to auth microservice
"""
import logging
import httpx
from fastapi import APIRouter, Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from typing import Dict, Any
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
logger = logging.getLogger(__name__)
router = APIRouter()
# Initialize service discovery and metrics
service_discovery = ServiceDiscovery()
metrics = MetricsCollector("gateway")
# Register custom metrics for auth routes
metrics.register_counter("gateway_auth_requests_total", "Total authentication requests through gateway")
metrics.register_counter("gateway_auth_responses_total", "Total authentication responses from gateway")
metrics.register_counter("gateway_auth_errors_total", "Total authentication errors in gateway")
# Auth service configuration
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"
class AuthProxy:
"""Authentication service proxy with enhanced error handling"""
def __init__(self):
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
)
async def forward_request(
self,
method: str,
path: str,
request: Request
) -> Response:
"""Forward request to auth service with proper error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
# Get auth service URL (with service discovery if available)
auth_url = await self._get_auth_service_url()
target_url = f"{auth_url}/{path}"
# Prepare headers (remove hop-by-hop headers)
# IMPORTANT: Use request.headers directly to get headers added by middleware
# Also check request.state for headers injected by middleware
headers = self._prepare_headers(request.headers, request)
# Get request body
body = await request.body()
# Forward request
logger.info(f"Forwarding {method} /{path} to auth service")
response = await self.client.request(
method=method,
url=target_url,
headers=headers,
content=body,
params=dict(request.query_params)
)
# Record metrics
metrics.increment_counter("gateway_auth_requests_total")
metrics.increment_counter(
"gateway_auth_responses_total",
labels={"status_code": str(response.status_code)}
)
# Prepare response headers
response_headers = self._prepare_response_headers(dict(response.headers))
return Response(
content=response.content,
status_code=response.status_code,
headers=response_headers,
media_type=response.headers.get("content-type")
)
except httpx.TimeoutException:
logger.error(f"Timeout forwarding {method} /{path} to auth service")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Authentication service timeout"
)
except httpx.ConnectError:
logger.error(f"Connection error forwarding {method} /{path} to auth service")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service unavailable"
)
except Exception as e:
logger.error(f"Error forwarding {method} /{path} to auth service: {e}")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal gateway error"
)
async def _get_auth_service_url(self) -> str:
"""Get auth service URL with service discovery"""
try:
# Try service discovery first
service_url = await service_discovery.get_service_url("auth-service")
if service_url:
return service_url
except Exception as e:
logger.warning(f"Service discovery failed: {e}")
# Fall back to configured URL
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
else:
# Fallback: convert headers to dict manually
all_headers = {}
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
elif hasattr(headers, 'raw'):
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Headers is already a dict
all_headers = dict(headers)
        # Debug logging of forwarded identity headers (debug level avoids dumping headers on every request)
        logger.debug("Forwarding headers - x-user-id: %s, x-is-demo: %s, x-demo-session-id: %s, keys: %s",
                     all_headers.get('x-user-id', 'MISSING'), all_headers.get('x-is-demo', 'MISSING'),
                     all_headers.get('x-demo-session-id', 'MISSING'), list(all_headers.keys()))
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""
# Remove server-specific headers
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in {'server', 'date'}
}
# Add CORS headers if needed
if settings.CORS_ORIGINS:
filtered_headers['Access-Control-Allow-Origin'] = '*'
filtered_headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
filtered_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
return filtered_headers
# Initialize proxy
auth_proxy = AuthProxy()
# ================================================================
# CATCH-ALL ROUTE for all auth and user endpoints
# ================================================================
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_auth_requests(path: str, request: Request):
"""Catch-all proxy for all auth and user requests"""
return await auth_proxy.forward_request(request.method, f"api/v1/auth/{path}", request)
# ================================================================
# HEALTH CHECK for auth service
# ================================================================
@router.get("/health")
async def auth_service_health():
"""Check auth service health"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{AUTH_SERVICE_URL}/health")
if response.status_code == 200:
return {
"status": "healthy",
"auth_service": "available",
"response_time_ms": response.elapsed.total_seconds() * 1000
}
else:
return {
"status": "unhealthy",
"auth_service": "error",
"status_code": response.status_code
}
except Exception as e:
logger.error(f"Auth service health check failed: {e}")
return {
"status": "unhealthy",
"auth_service": "unavailable",
"error": str(e)
}
# ================================================================
# CLEANUP
# ================================================================
@router.on_event("shutdown")
async def cleanup():
"""Cleanup resources"""
await auth_proxy.client.aclose()

View File

@@ -0,0 +1,58 @@
"""
Demo Session Routes - Proxy to demo-session service
"""
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@router.api_route("/demo/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_demo_service(path: str, request: Request):
"""
Proxy all demo requests to the demo-session service
These endpoints are public and don't require authentication
"""
# Build the target URL
demo_service_url = settings.DEMO_SESSION_SERVICE_URL.rstrip('/')
target_url = f"{demo_service_url}/api/v1/demo/{path}"
# Get request body
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request.method,
url=target_url,
headers=headers,
params=request.query_params,
content=body
)
# Return the response
return JSONResponse(
content=response.json() if response.content else {},
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.RequestError as e:
logger.error("Failed to proxy to demo-session service", error=str(e), url=target_url)
raise HTTPException(status_code=503, detail="Demo service unavailable")
except Exception as e:
logger.error("Unexpected error proxying to demo-session service", error=str(e))
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,71 @@
# gateway/app/routes/geocoding.py
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def proxy_geocoding(request: Request, path: str):
"""
Proxies all geocoding requests to the External Service geocoding endpoints.
Forwards requests from /api/v1/geocoding/* to external-service:8000/api/v1/geocoding/*
"""
try:
# Construct the external service URL
external_url = f"{settings.EXTERNAL_SERVICE_URL}/api/v1/geocoding/{path}"
# Get request body for POST/PUT/PATCH
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the proxied request
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request.method,
url=external_url,
params=request.query_params,
headers=headers,
content=body
)
# Return the response from external service
return JSONResponse(
content=response.json() if response.text else None,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.RequestError as exc:
logger.error("External service geocoding request failed", error=str(exc), path=path)
raise HTTPException(
status_code=503,
detail=f"Geocoding service unavailable: {exc}"
)
except httpx.HTTPStatusError as exc:
logger.error(
f"External service geocoding responded with error {exc.response.status_code}",
detail=exc.response.text,
path=path
)
raise HTTPException(
status_code=exc.response.status_code,
detail=f"Geocoding service error: {exc.response.text}"
)
except Exception as exc:
logger.error("Unexpected error in geocoding proxy", error=str(exc), path=path)
raise HTTPException(
status_code=500,
detail="Internal server error in geocoding proxy"
)

View File

@@ -0,0 +1,61 @@
# gateway/app/routes/nominatim.py
from fastapi import APIRouter, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
@router.get("/search")
async def proxy_nominatim_search(request: Request):
"""
Proxies requests to the Nominatim geocoding search API.
"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
# Construct the internal Nominatim URL
# All query parameters from the client request are forwarded
nominatim_url = f"{settings.NOMINATIM_SERVICE_URL}/nominatim/search"
# httpx client for making async HTTP requests
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(
nominatim_url,
params=request.query_params # Forward all query parameters from frontend
)
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
# Return the JSON response from Nominatim directly
return JSONResponse(content=response.json())
except httpx.RequestError as exc:
logger.error("Nominatim service request failed", error=str(exc))
raise HTTPException(
status_code=503,
detail=f"Nominatim service unavailable: {exc}"
)
except httpx.HTTPStatusError as exc:
logger.error(f"Nominatim service responded with error {exc.response.status_code}",
detail=exc.response.text)
raise HTTPException(
status_code=exc.response.status_code,
detail=f"Nominatim service error: {exc.response.text}"
)
except Exception as exc:
logger.error("Unexpected error in Nominatim proxy", error=str(exc))
raise HTTPException(status_code=500, detail="Internal server error in Nominatim proxy")

View File

@@ -0,0 +1,88 @@
"""
POI Context Proxy Router
Forwards all POI context requests to the External Service
"""
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import JSONResponse
import httpx
import structlog
from app.core.config import settings
from app.core.header_manager import header_manager
logger = structlog.get_logger()
router = APIRouter()
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def proxy_poi_context(request: Request, path: str):
"""
Proxies all POI context requests to the External Service.
    Forwards requests from /api/v1/poi-context/* to {EXTERNAL_SERVICE_URL}/poi-context/*
Args:
request: The incoming request
path: The path after /api/v1/poi-context/
Returns:
JSONResponse with the response from the external service
Raises:
HTTPException: If the external service is unavailable or returns an error
"""
try:
# Construct the external service URL
external_url = f"{settings.EXTERNAL_SERVICE_URL}/poi-context/{path}"
logger.debug("Proxying POI context request",
method=request.method,
path=path,
external_url=external_url)
# Get request body for POST/PUT/PATCH requests
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Make the request to the external service
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.request(
method=request.method,
url=external_url,
params=request.query_params,
headers=headers,
content=body
)
logger.debug("POI context proxy response",
status_code=response.status_code,
path=path)
# Return the response from the external service
return JSONResponse(
content=response.json() if response.text else None,
status_code=response.status_code,
headers=dict(response.headers)
)
except httpx.RequestError as exc:
logger.error("External service POI request failed",
error=str(exc),
path=path,
external_url=external_url)
raise HTTPException(
status_code=503,
detail=f"POI service unavailable: {exc}"
)
except Exception as exc:
logger.error("Unexpected error in POI proxy",
error=str(exc),
path=path)
raise HTTPException(
status_code=500,
detail="Internal server error in POI proxy"
)

gateway/app/routes/pos.py Normal file

@@ -0,0 +1,89 @@
"""
POS routes for API Gateway - Global POS endpoints
"""
from fastapi import APIRouter, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import logging
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# GLOBAL POS ENDPOINTS (No tenant context required)
# ================================================================
@router.api_route("/supported-systems", methods=["GET", "OPTIONS"])
async def proxy_supported_systems(request: Request):
"""Proxy supported POS systems request to POS service"""
target_path = "/api/v1/pos/supported-systems"
return await _proxy_to_pos_service(request, target_path)
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_pos_service(request: Request, target_path: str):
"""Proxy request to POS service"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
url = f"{settings.POS_SERVICE_URL}{target_path}"
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0,
read=60.0,
write=30.0,
pool=30.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except ValueError:  # downstream service returned a non-JSON body
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=content
)
except Exception as e:
logger.error(f"Unexpected error proxying to POS service {target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)


@@ -0,0 +1,116 @@
"""
Registration routes for API Gateway - Handles registration-specific endpoints
These endpoints don't require a tenant ID and are used during the registration flow
"""
from fastapi import APIRouter, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# REGISTRATION ENDPOINTS - Direct routing to tenant service
# These endpoints are called during registration before a tenant exists
# ================================================================
@router.post("/registration-payment-setup")
async def proxy_registration_payment_setup(request: Request):
"""Proxy registration payment setup request to tenant service"""
return await _proxy_to_tenant_service(request, "/api/v1/tenants/registration-payment-setup")
@router.post("/verify-and-complete-registration")
async def proxy_verify_and_complete_registration(request: Request):
"""Proxy verification and registration completion to tenant service"""
return await _proxy_to_tenant_service(request, "/api/v1/tenants/verify-and-complete-registration")
@router.post("/payment-customers/create")
async def proxy_registration_customer_create(request: Request):
"""Proxy registration customer creation to tenant service"""
return await _proxy_to_tenant_service(request, "/api/v1/payment-customers/create")
@router.get("/setup-intents/{setup_intent_id}/verify")
async def proxy_registration_setup_intent_verify(request: Request, setup_intent_id: str):
"""Proxy registration setup intent verification to tenant service"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/setup-intents/{setup_intent_id}/verify")
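# Illustrative flow for the routes above (the gateway mount prefix is an assumption):
#   POST {gateway}/registration-payment-setup
#     -> {TENANT_SERVICE_URL}/api/v1/tenants/registration-payment-setup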
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_tenant_service(request: Request, target_path: str):
"""Generic proxy function with enhanced error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID, Stripe-Signature",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
url = f"{settings.TENANT_SERVICE_URL}{target_path}"
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Debug logging
logger.info(f"Forwarding registration request to {url}")
# Get request body if present
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0,
read=60.0,
write=30.0,
pool=30.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except ValueError:  # downstream service returned a non-JSON body
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=content
)
except Exception as e:
logger.error(f"Unexpected error proxying registration request to {settings.TENANT_SERVICE_URL}{target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)


@@ -0,0 +1,317 @@
"""
Subscription routes for API Gateway - Direct subscription endpoints
New URL Pattern Architecture:
- Registration: /registration/payment-setup, /registration/complete, /registration/state/{state_id}
- Tenant Subscription: /tenants/{tenant_id}/subscription/*
- Setup Intents: /setup-intents/{setup_intent_id}/verify
- Payment Customers: /payment-customers/create
- Plans: /plans (public)
"""
from fastapi import APIRouter, Request, Response, HTTPException, Path
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# PUBLIC ENDPOINTS (No Authentication)
# ================================================================
@router.api_route("/plans", methods=["GET", "OPTIONS"])
async def proxy_plans(request: Request):
"""Proxy plans request to tenant service"""
target_path = "/plans"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/plans/{tier}", methods=["GET", "OPTIONS"])
async def proxy_plan_details(request: Request, tier: str = Path(...)):
"""Proxy specific plan details request to tenant service"""
target_path = f"/plans/{tier}"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/plans/{tier}/features", methods=["GET", "OPTIONS"])
async def proxy_plan_features(request: Request, tier: str = Path(...)):
"""Proxy plan features request to tenant service"""
target_path = f"/plans/{tier}/features"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/plans/{tier}/limits", methods=["GET", "OPTIONS"])
async def proxy_plan_limits(request: Request, tier: str = Path(...)):
"""Proxy plan limits request to tenant service"""
target_path = f"/plans/{tier}/limits"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/plans/compare", methods=["GET", "OPTIONS"])
async def proxy_plan_compare(request: Request):
"""Proxy plan comparison request to tenant service"""
target_path = "/plans/compare"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# REGISTRATION FLOW ENDPOINTS (No Tenant Context)
# ================================================================
@router.api_route("/registration/payment-setup", methods=["POST", "OPTIONS"])
async def proxy_registration_payment_setup(request: Request):
"""Proxy registration payment setup request to tenant service"""
target_path = "/api/v1/registration/payment-setup"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/registration/complete", methods=["POST", "OPTIONS"])
async def proxy_registration_complete(request: Request):
"""Proxy registration completion request to tenant service"""
target_path = "/api/v1/registration/complete"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/registration/state/{state_id}", methods=["GET", "OPTIONS"])
async def proxy_registration_state(request: Request, state_id: str = Path(...)):
"""Proxy registration state request to tenant service"""
target_path = f"/api/v1/registration/state/{state_id}"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# TENANT SUBSCRIPTION STATUS ENDPOINTS
# ================================================================
@router.api_route("/tenants/{tenant_id}/subscription/status", methods=["GET", "OPTIONS"])
async def proxy_subscription_status(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription status request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/status"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/details", methods=["GET", "OPTIONS"])
async def proxy_subscription_details(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription details request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/details"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/tier", methods=["GET", "OPTIONS"])
async def proxy_subscription_tier(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription tier request to tenant service (cached)"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/tier"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/limits", methods=["GET", "OPTIONS"])
async def proxy_subscription_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription limits request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/usage", methods=["GET", "OPTIONS"])
async def proxy_subscription_usage(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription usage request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/usage"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/features/{feature}", methods=["GET", "OPTIONS"])
async def proxy_subscription_feature(request: Request, tenant_id: str = Path(...), feature: str = Path(...)):
"""Proxy subscription feature check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/features/{feature}"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# SUBSCRIPTION MANAGEMENT ENDPOINTS
# ================================================================
@router.api_route("/tenants/{tenant_id}/subscription/cancel", methods=["POST", "OPTIONS"])
async def proxy_subscription_cancel(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription cancellation request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/cancel"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/reactivate", methods=["POST", "OPTIONS"])
async def proxy_subscription_reactivate(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription reactivation request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/reactivate"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan}", methods=["GET", "OPTIONS"])
async def proxy_validate_upgrade(request: Request, tenant_id: str = Path(...), new_plan: str = Path(...)):
"""Proxy plan upgrade validation request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan}"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/upgrade", methods=["POST", "OPTIONS"])
async def proxy_subscription_upgrade(request: Request, tenant_id: str = Path(...)):
"""Proxy subscription upgrade request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/upgrade"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# QUOTA & LIMIT CHECK ENDPOINTS
# ================================================================
@router.api_route("/tenants/{tenant_id}/subscription/limits/locations", methods=["GET", "OPTIONS"])
async def proxy_location_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy location limits check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/locations"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/limits/products", methods=["GET", "OPTIONS"])
async def proxy_product_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy product limits check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/products"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/limits/users", methods=["GET", "OPTIONS"])
async def proxy_user_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy user limits check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/users"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/limits/recipes", methods=["GET", "OPTIONS"])
async def proxy_recipe_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy recipe limits check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/recipes"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/limits/suppliers", methods=["GET", "OPTIONS"])
async def proxy_supplier_limits(request: Request, tenant_id: str = Path(...)):
"""Proxy supplier limits check request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/suppliers"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# PAYMENT MANAGEMENT ENDPOINTS
# ================================================================
@router.api_route("/tenants/{tenant_id}/subscription/payment-method", methods=["GET", "POST", "OPTIONS"])
async def proxy_payment_method(request: Request, tenant_id: str = Path(...)):
"""Proxy payment method request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/payment-method"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/tenants/{tenant_id}/subscription/invoices", methods=["GET", "OPTIONS"])
async def proxy_invoices(request: Request, tenant_id: str = Path(...)):
"""Proxy invoices request to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/subscription/invoices"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# SETUP INTENT VERIFICATION
# ================================================================
@router.api_route("/setup-intents/{setup_intent_id}/verify", methods=["GET", "OPTIONS"])
async def proxy_setup_intent_verify(request: Request, setup_intent_id: str = Path(...)):
"""Proxy SetupIntent verification request to tenant service"""
target_path = f"/api/v1/setup-intents/{setup_intent_id}/verify"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# PAYMENT CUSTOMER MANAGEMENT
# ================================================================
@router.api_route("/payment-customers/create", methods=["POST", "OPTIONS"])
async def proxy_payment_customer_create(request: Request):
"""Proxy payment customer creation request to tenant service"""
target_path = "/api/v1/payment-customers/create"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# USAGE FORECAST ENDPOINTS
# ================================================================
@router.api_route("/usage-forecast", methods=["GET", "OPTIONS"])
async def proxy_usage_forecast(request: Request):
"""Proxy usage forecast request to tenant service"""
target_path = "/api/v1/usage-forecast"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/usage-forecast/track-usage", methods=["POST", "OPTIONS"])
async def proxy_track_usage(request: Request):
"""Proxy track usage request to tenant service"""
target_path = "/api/v1/usage-forecast/track-usage"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_tenant_service(request: Request, target_path: str):
"""Proxy request to tenant service"""
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
async def _proxy_request(request: Request, target_path: str, service_url: str):
"""Generic proxy function with enhanced error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
url = f"{service_url}{target_path}"
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Debug logging
user_context = getattr(request.state, 'user', None)
service_context = getattr(request.state, 'service', None)
if user_context:
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
elif service_context:
logger.debug(f"Forwarding subscription request to {url} with service context: service_name={service_context.get('service_name')}, user_type=service")
else:
logger.warning(f"No user or service context available when forwarding subscription request to {url}")
# Get request body if present
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0,
read=60.0,
write=30.0,
pool=30.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except ValueError:  # downstream service returned a non-JSON body
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=content
)
except Exception as e:
logger.error(f"Unexpected error proxying subscription request to {service_url}{target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)


@@ -0,0 +1,303 @@
"""
Telemetry routes for API Gateway - Handles frontend telemetry data
This module provides endpoints for:
- Receiving OpenTelemetry traces from frontend
- Proxying traces to Signoz OTel collector
- Providing a secure, authenticated endpoint for frontend telemetry
"""
from fastapi import APIRouter, Request, HTTPException, status
from fastapi.responses import JSONResponse, Response
import httpx
import logging
import os
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
from shared.monitoring.metrics import MetricsCollector, create_metrics_collector
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
# Get Signoz OTel collector endpoint from environment or use default
SIGNOZ_OTEL_COLLECTOR = os.getenv(
"SIGNOZ_OTEL_COLLECTOR_URL",
"http://signoz-otel-collector.bakery-ia.svc.cluster.local:4318"
)
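# 4318 is the default OTLP/HTTP port; set SIGNOZ_OTEL_COLLECTOR_URL if the collector
# is exposed elsewhere (the OTLP/gRPC port 4317 is not usable with this HTTP proxy).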
@router.post("/v1/traces")
async def receive_frontend_traces(request: Request):
"""
Receive OpenTelemetry traces from frontend and proxy to Signoz
This endpoint:
- Accepts OTLP trace data from frontend
- Validates the request
- Proxies to Signoz OTel collector
- Handles errors gracefully
"""
# Handle OPTIONS requests for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
# Get the trace data from the request
body = await request.body()
if not body:
logger.warning("Received empty trace data from frontend")
return JSONResponse(
status_code=400,
content={"error": "Empty trace data"}
)
# Log the trace reception (without sensitive data)
logger.info(
"Received frontend traces, content_length=%s, content_type=%s, user_agent=%s",
len(body),
request.headers.get("content-type"),
request.headers.get("user-agent")
)
# Forward to Signoz OTel collector
target_url = f"{SIGNOZ_OTEL_COLLECTOR}/v1/traces"
# Set up headers for the Signoz collector
forward_headers = {
"Content-Type": request.headers.get("content-type", "application/json"),
"User-Agent": "bakery-gateway/1.0",
"X-Forwarded-For": request.headers.get("x-forwarded-for", "frontend"),
"X-Tenant-ID": request.headers.get("x-tenant-id", "unknown")
}
# Add authentication if configured
signoz_auth_token = os.getenv("SIGNOZ_AUTH_TOKEN")
if signoz_auth_token:
forward_headers["Authorization"] = f"Bearer {signoz_auth_token}"
# Send to Signoz collector
timeout_config = httpx.Timeout(
connect=5.0,
read=10.0,
write=5.0,
pool=5.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.post(
url=target_url,
content=body,
headers=forward_headers
)
# Log the response from Signoz
logger.info(
"Forwarded traces to Signoz, signoz_status=%s, signoz_response_time=%s",
response.status_code,
response.elapsed.total_seconds()
)
# Return success response to frontend
return JSONResponse(
status_code=200,
content={
"message": "Traces received and forwarded to Signoz",
"signoz_status": response.status_code,
"trace_count": 1 # We don't know exact count without parsing
}
)
except httpx.HTTPStatusError as e:
logger.error(
"Signoz collector returned error, status_code=%s, error_message=%s",
e.response.status_code,
str(e)
)
return JSONResponse(
status_code=502,
content={
"error": "Signoz collector error",
"details": str(e),
"signoz_status": e.response.status_code
}
)
except httpx.RequestError as e:
logger.error(
"Failed to connect to Signoz collector, error=%s, collector_url=%s",
str(e),
SIGNOZ_OTEL_COLLECTOR
)
return JSONResponse(
status_code=503,
content={
"error": "Signoz collector unavailable",
"details": str(e)
}
)
except Exception as e:
logger.error(
"Unexpected error processing traces, error=%s, error_type=%s",
str(e),
type(e).__name__
)
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"details": str(e)
}
)
@router.post("/v1/metrics")
async def receive_frontend_metrics(request: Request):
"""
Receive OpenTelemetry metrics from frontend and proxy to Signoz
"""
# Handle OPTIONS requests for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
body = await request.body()
if not body:
return JSONResponse(
status_code=400,
content={"error": "Empty metrics data"}
)
logger.info(
"Received frontend metrics, content_length=%s, content_type=%s",
len(body),
request.headers.get("content-type")
)
# Forward to Signoz OTel collector
target_url = f"{SIGNOZ_OTEL_COLLECTOR}/v1/metrics"
forward_headers = {
"Content-Type": request.headers.get("content-type", "application/json"),
"User-Agent": "bakery-gateway/1.0",
"X-Forwarded-For": request.headers.get("x-forwarded-for", "frontend"),
"X-Tenant-ID": request.headers.get("x-tenant-id", "unknown")
}
# Add authentication if configured
signoz_auth_token = os.getenv("SIGNOZ_AUTH_TOKEN")
if signoz_auth_token:
forward_headers["Authorization"] = f"Bearer {signoz_auth_token}"
timeout_config = httpx.Timeout(
connect=5.0,
read=10.0,
write=5.0,
pool=5.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.post(
url=target_url,
content=body,
headers=forward_headers
)
logger.info(
"Forwarded metrics to Signoz, signoz_status=%s",
response.status_code
)
return JSONResponse(
status_code=200,
content={
"message": "Metrics received and forwarded to Signoz",
"signoz_status": response.status_code
}
)
except Exception as e:
logger.error(
"Error processing metrics, error=%s",
str(e)
)
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"details": str(e)
}
)
@router.get("/health")
async def telemetry_health():
"""
Health check endpoint for telemetry service
"""
return JSONResponse(
status_code=200,
content={
"status": "healthy",
"service": "telemetry-gateway",
"signoz_collector": SIGNOZ_OTEL_COLLECTOR
}
)
# Initialize metrics for this module
try:
metrics_collector = create_metrics_collector("gateway-telemetry")
except Exception as e:
logger.error("Failed to create metrics collector, error=%s", str(e))
metrics_collector = None
@router.on_event("startup")
async def startup_event():
"""Initialize telemetry metrics on startup"""
try:
if metrics_collector:
# Register telemetry-specific metrics
metrics_collector.register_counter(
"gateway_telemetry_traces_received",
"Number of trace batches received from frontend"
)
metrics_collector.register_counter(
"gateway_telemetry_metrics_received",
"Number of metric batches received from frontend"
)
metrics_collector.register_counter(
"gateway_telemetry_errors",
"Number of telemetry processing errors"
)
logger.info(
"Telemetry gateway initialized, signoz_collector=%s",
SIGNOZ_OTEL_COLLECTOR
)
except Exception as e:
logger.error(
"Failed to initialize telemetry metrics, error=%s",
str(e)
)


@@ -0,0 +1,822 @@
# gateway/app/routes/tenant.py - COMPLETELY UPDATED
"""
Tenant routes for API Gateway - Handles all tenant-scoped endpoints
"""
from fastapi import APIRouter, Request, Response, HTTPException, Path
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# TENANT MANAGEMENT ENDPOINTS
# ================================================================
@router.post("/register")
async def create_tenant(request: Request):
"""Proxy tenant creation to tenant service"""
return await _proxy_to_tenant_service(request, "/api/v1/tenants/register")
@router.get("/{tenant_id}")
async def get_tenant(request: Request, tenant_id: str = Path(...)):
"""Get specific tenant details"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
@router.put("/{tenant_id}")
async def update_tenant(request: Request, tenant_id: str = Path(...)):
"""Update tenant details"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
@router.get("/{tenant_id}/members")
async def get_tenant_members(request: Request, tenant_id: str = Path(...)):
"""Get tenant members"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
@router.post("/{tenant_id}/members")
async def add_tenant_member(request: Request, tenant_id: str = Path(...)):
"""Add a team member to tenant"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
@router.post("/{tenant_id}/members/with-user")
async def add_tenant_member_with_user(request: Request, tenant_id: str = Path(...)):
"""Add a team member to tenant with user creation"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/with-user")
@router.put("/{tenant_id}/members/{member_user_id}/role")
async def update_member_role(request: Request, tenant_id: str = Path(...), member_user_id: str = Path(...)):
"""Update team member role"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/{member_user_id}/role")
@router.delete("/{tenant_id}/members/{member_user_id}")
async def remove_tenant_member(request: Request, tenant_id: str = Path(...), member_user_id: str = Path(...)):
"""Remove team member from tenant"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/{member_user_id}")
@router.post("/{tenant_id}/transfer-ownership")
async def transfer_tenant_ownership(request: Request, tenant_id: str = Path(...)):
"""Transfer tenant ownership to another admin"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/transfer-ownership")
@router.get("/{tenant_id}/admins")
async def get_tenant_admins(request: Request, tenant_id: str = Path(...)):
"""Get all admins for a tenant"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/admins")
@router.get("/{tenant_id}/hierarchy")
async def get_tenant_hierarchy(request: Request, tenant_id: str = Path(...)):
"""Get tenant hierarchy information"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/hierarchy")
@router.api_route("/{tenant_id}/children", methods=["GET", "OPTIONS"])
async def get_tenant_children(request: Request, tenant_id: str = Path(...)):
"""Get tenant children"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/children")
@router.api_route("/{tenant_id}/bulk-children", methods=["POST", "OPTIONS"])
async def proxy_bulk_children(request: Request, tenant_id: str = Path(...)):
"""Proxy bulk children creation requests to tenant service"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/bulk-children")
@router.api_route("/{tenant_id}/children/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_children(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant children requests to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/children/{path}".rstrip("/")
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/{tenant_id}/access/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_access(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant access requests to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/access/{path}".rstrip("/")
return await _proxy_to_tenant_service(request, target_path)
@router.get("/{tenant_id}/my-access")
async def get_tenant_my_access(request: Request, tenant_id: str = Path(...)):
"""Get current user's access level for a tenant"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/my-access")
@router.get("/user/{user_id}")
async def get_user_tenants(request: Request, user_id: str = Path(...)):
"""Get all tenant memberships for a user (admin only)"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/users/{user_id}")
@router.get("/user/{user_id}/owned")
async def get_user_owned_tenants(request: Request, user_id: str = Path(...)):
"""Get all tenants owned by a user"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/owned")
@router.get("/user/{user_id}/tenants")
async def get_user_all_tenants(request: Request, user_id: str = Path(...)):
"""Get all tenants accessible by a user (both owned and member tenants)"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/tenants")
@router.get("/users/{user_id}/primary-tenant")
async def get_user_primary_tenant(request: Request, user_id: str = Path(...)):
"""Get the primary tenant for a user (used by auth service for subscription validation)"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/users/{user_id}/primary-tenant")
@router.delete("/user/{user_id}/memberships")
async def delete_user_tenants(request: Request, user_id: str = Path(...)):
"""Get all tenant memberships for a user (admin only)"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/memberships")
# ================================================================
# TENANT SETTINGS ENDPOINTS
# ================================================================
@router.get("/{tenant_id}/settings")
async def get_tenant_settings(request: Request, tenant_id: str = Path(...)):
"""Get all settings for a tenant"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings")
@router.put("/{tenant_id}/settings")
async def update_tenant_settings(request: Request, tenant_id: str = Path(...)):
"""Update tenant settings"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings")
@router.get("/{tenant_id}/settings/{category}")
async def get_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
"""Get settings for a specific category"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}")
@router.put("/{tenant_id}/settings/{category}")
async def update_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
"""Update settings for a specific category"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}")
@router.post("/{tenant_id}/settings/{category}/reset")
async def reset_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
"""Reset a category to default values"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}/reset")
# ================================================================
# TENANT SUBSCRIPTION ENDPOINTS
# ================================================================
# NOTE: All subscription endpoints have been moved to gateway/app/routes/subscription.py
# as part of the architecture redesign for better separation of concerns.
# This wildcard route has been removed to avoid conflicts with the new specific routes.
@router.api_route("/subscriptions/plans", methods=["GET", "OPTIONS"])
async def proxy_available_plans(request: Request):
"""Proxy available plans request to tenant service"""
target_path = "/api/v1/plans/available"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# BATCH OPERATIONS ENDPOINTS
# IMPORTANT: Route order matters! Keep specific routes before wildcards:
# 1. Exact matches first (/batch/sales-summary)
# 2. Wildcard paths second (/batch{path:path})
# 3. Tenant-scoped wildcards last (/{tenant_id}/batch{path:path})
# ================================================================
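# Illustrative matching (hypothetical requests): POST /batch/sales-summary hits the
# exact route below; POST /batch/other-report falls through to /batch{path:path};
# POST /{tenant_id}/batch/sales-summary is handled by the tenant-scoped wildcard.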
@router.api_route("/batch/sales-summary", methods=["POST"])
async def proxy_batch_sales_summary(request: Request):
"""Proxy batch sales summary request to sales service"""
target_path = "/api/v1/batch/sales-summary"
return await _proxy_to_sales_service(request, target_path)
@router.api_route("/batch/production-summary", methods=["POST"])
async def proxy_batch_production_summary(request: Request):
"""Proxy batch production summary request to production service"""
target_path = "/api/v1/batch/production-summary"
return await _proxy_to_production_service(request, target_path)
@router.api_route("/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
async def proxy_batch_operations(request: Request, path: str = ""):
"""Proxy batch operations that span multiple tenants to appropriate services"""
# For batch operations, route based on the path after /batch/
if path.startswith("/sales-summary"):
# Route batch sales summary to sales service
# The sales service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_sales_service(request, target_path)
elif path.startswith("/production-summary"):
# Route batch production summary to production service
# The production service batch endpoints are at /api/v1/batch/... not /api/v1/production/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_production_service(request, target_path)
else:
# Default to sales service for other batch operations
# The service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_sales_service(request, target_path)
@router.api_route("/{tenant_id}/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_batch_operations(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant-scoped batch operations to appropriate services"""
# For tenant-scoped batch operations, route based on the path after /batch/
if path.startswith("/sales-summary"):
# Route tenant batch sales summary to sales service
# The sales service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_sales_service(request, target_path)
elif path.startswith("/production-summary"):
# Route tenant batch production summary to production service
# The production service batch endpoints are at /api/v1/batch/... not /api/v1/production/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_production_service(request, target_path)
else:
# Default to sales service for other tenant batch operations
# The service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
target_path = f"/api/v1/batch{path}"
return await _proxy_to_sales_service(request, target_path)
# ================================================================
# TENANT-SCOPED DATA SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/sales{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_all_tenant_sales_alternative(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy all tenant sales requests - handles both base and sub-paths"""
base_path = f"/api/v1/tenants/{tenant_id}/sales"
# If path is empty or just "/", use base path
if not path or path == "/":
target_path = base_path
else:
# Ensure path starts with "/"
if not path.startswith("/"):
path = "/" + path
target_path = base_path + path
return await _proxy_to_sales_service(request, target_path)
@router.api_route("/{tenant_id}/enterprise/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_enterprise_batch(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy enterprise batch requests (spanning multiple tenants within an enterprise) to appropriate services"""
# Forward to orchestrator service for enterprise-level operations
target_path = f"/api/v1/tenants/{tenant_id}/enterprise/batch{path}".rstrip("/")
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_weather(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant weather requests to external service"""
target_path = f"/api/v1/tenants/{tenant_id}/weather/{path}".rstrip("/")
return await _proxy_to_external_service(request, target_path)
@router.api_route("/{tenant_id}/traffic/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_traffic(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant traffic requests to external service"""
target_path = f"/api/v1/tenants/{tenant_id}/traffic/{path}".rstrip("/")
return await _proxy_to_external_service(request, target_path)
@router.api_route("/{tenant_id}/external/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_external(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant external service requests (v2.0 city-based optimized endpoints)"""
# Route to external service with normal path structure
target_path = f"/api/v1/tenants/{tenant_id}/external/{path}".rstrip("/")
return await _proxy_to_external_service(request, target_path)
# Service-specific analytics routes (must come BEFORE the general analytics route)
@router.api_route("/{tenant_id}/procurement/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_procurement_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant procurement analytics requests to procurement service"""
target_path = f"/api/v1/tenants/{tenant_id}/procurement/analytics/{path}".rstrip("/")
return await _proxy_to_procurement_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/inventory/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_inventory_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant inventory analytics requests to inventory service"""
target_path = f"/api/v1/tenants/{tenant_id}/inventory/analytics/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/production/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_production_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant production analytics requests to production service"""
target_path = f"/api/v1/tenants/{tenant_id}/production/analytics/{path}".rstrip("/")
return await _proxy_to_production_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/sales/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_sales_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant sales analytics requests to sales service"""
target_path = f"/api/v1/tenants/{tenant_id}/sales/analytics/{path}".rstrip("/")
return await _proxy_to_sales_service(request, target_path)
@router.api_route("/{tenant_id}/analytics/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant analytics requests to sales service (fallback for non-service-specific analytics)"""
target_path = f"/api/v1/tenants/{tenant_id}/analytics/{path}".rstrip("/")
return await _proxy_to_sales_service(request, target_path)
# ================================================================
# TENANT-SCOPED AI INSIGHTS ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/insights{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
async def proxy_tenant_insights(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant AI insights requests to AI insights service"""
target_path = f"/api/v1/tenants/{tenant_id}/insights{path}".rstrip("/")
return await _proxy_to_ai_insights_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/onboarding/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_onboarding(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant onboarding requests to tenant service"""
target_path = f"/api/v1/tenants/{tenant_id}/onboarding/{path}".rstrip("/")
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# TENANT-SCOPED TRAINING SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/training/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_training(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant training requests to training service"""
target_path = f"/api/v1/tenants/{tenant_id}/training/{path}".rstrip("/")
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/models/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_models(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant model requests to training service"""
target_path = f"/api/v1/tenants/{tenant_id}/models/{path}".rstrip("/")
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/statistics", methods=["GET", "OPTIONS"])
async def proxy_tenant_statistics(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant statistics requests to training service"""
target_path = f"/api/v1/tenants/{tenant_id}/statistics"
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED FORECASTING SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/forecasting/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_forecasting(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant forecasting requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/forecasting/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/forecasting/enterprise/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_forecasting_enterprise(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant forecasting enterprise requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/forecasting/enterprise/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_forecasts(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant forecast requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/forecasts/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_predictions(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant prediction requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/predictions/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED NOTIFICATION SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/notifications/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_notifications(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant notification requests to notification service"""
target_path = f"/api/v1/tenants/{tenant_id}/notifications/{path}".rstrip("/")
return await _proxy_to_notification_service(request, target_path)
# ================================================================
# TENANT-SCOPED ALERT ANALYTICS ENDPOINTS (Must come BEFORE inventory alerts)
# ================================================================
# Exact match for /alerts endpoint (without additional path)
@router.api_route("/{tenant_id}/alerts", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_alerts_list(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant alerts list requests to alert processor service"""
target_path = f"/api/v1/tenants/{tenant_id}/alerts"
return await _proxy_to_alert_processor_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/alerts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_alert_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant alert analytics requests to alert processor service"""
target_path = f"/api/v1/tenants/{tenant_id}/alerts/{path}".rstrip("/")
return await _proxy_to_alert_processor_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED INVENTORY SERVICE ENDPOINTS
# ================================================================
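# NOTE (route order): the alert-processor routes above already capture
# /{tenant_id}/alerts and /{tenant_id}/alerts/{path}, so this inventory route only
# matches non-slash suffixes; it remains for legacy/food-safety alert paths.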
@router.api_route("/{tenant_id}/alerts{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_alerts_inventory(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant alerts requests to inventory service (legacy/food-safety alerts)"""
target_path = f"/api/v1/tenants/{tenant_id}/alerts{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/inventory/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_inventory(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant inventory requests to inventory service"""
# The inventory service expects /api/v1/tenants/{tenant_id}/inventory/{path}
# Keep the full path structure for inventory service
target_path = f"/api/v1/tenants/{tenant_id}/inventory/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
# Specific route for ingredients without additional path
@router.api_route("/{tenant_id}/ingredients", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_ingredients_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant ingredient requests to inventory service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/ingredients"
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/ingredients/count", methods=["GET"])
async def proxy_tenant_ingredients_count(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant ingredient count requests to inventory service"""
# Inventory service uses RouteBuilder('inventory').build_base_route("ingredients/count")
# which generates /api/v1/tenants/{tenant_id}/inventory/ingredients/count
target_path = f"/api/v1/tenants/{tenant_id}/inventory/ingredients/count"
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/ingredients/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_ingredients_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant ingredient requests to inventory service (with additional path)"""
# The inventory service ingredient endpoints are now tenant-scoped: /api/v1/tenants/{tenant_id}/ingredients/{path}
# Keep the full tenant path structure
target_path = f"/api/v1/tenants/{tenant_id}/ingredients/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/stock/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_stock(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant stock requests to inventory service"""
# The inventory service stock endpoints are now tenant-scoped: /api/v1/tenants/{tenant_id}/stock/{path}
target_path = f"/api/v1/tenants/{tenant_id}/stock/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/dashboard/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_dashboard(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant dashboard requests to orchestrator service"""
# The orchestrator service dashboard endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/dashboard/{path}
target_path = f"/api/v1/tenants/{tenant_id}/dashboard/{path}".rstrip("/")
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/transformations", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_transformations_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant transformations requests to inventory service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/transformations"
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/transformations/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_transformations_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant transformations requests to inventory service (with additional path)"""
# The inventory service transformations endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/transformations/{path}
target_path = f"/api/v1/tenants/{tenant_id}/transformations/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/sustainability/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_sustainability(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant sustainability requests to inventory service"""
# The inventory service sustainability endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/sustainability/{path}
target_path = f"/api/v1/tenants/{tenant_id}/sustainability/{path}".rstrip("/")
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED PRODUCTION SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/production/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_production(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant production requests to production service"""
target_path = f"/api/v1/tenants/{tenant_id}/production/{path}".rstrip("/")
return await _proxy_to_production_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED ORCHESTRATOR SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/enterprise/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_enterprise(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant enterprise dashboard requests to orchestrator service"""
target_path = f"/api/v1/tenants/{tenant_id}/enterprise/{path}".rstrip("/")
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/orchestrator/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_orchestrator(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant orchestrator requests to orchestrator service"""
target_path = f"/api/v1/tenants/{tenant_id}/orchestrator/{path}".rstrip("/")
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED ORDERS SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/orders", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_orders_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant orders requests to orders service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/orders"
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/orders/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_orders_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant orders requests to orders service (with additional path)"""
target_path = f"/api/v1/tenants/{tenant_id}/orders/{path}".rstrip("/")
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/customers/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_customers(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant customers requests to orders service"""
target_path = f"/api/v1/tenants/{tenant_id}/orders/customers/{path}".rstrip("/")
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/procurement/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_procurement(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant procurement requests to procurement service"""
# For all procurement routes, we need to maintain the /procurement/ part in the path
# The procurement service now uses standardized paths with RouteBuilder
target_path = f"/api/v1/tenants/{tenant_id}/procurement/{path}".rstrip("/")
return await _proxy_to_procurement_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED SUPPLIER SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/suppliers", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_suppliers_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant supplier requests to suppliers service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/suppliers"
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/suppliers/count", methods=["GET"])
async def proxy_tenant_suppliers_count(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant supplier count requests to suppliers service"""
# Suppliers service uses RouteBuilder('suppliers').build_operations_route("count")
# which generates /api/v1/tenants/{tenant_id}/suppliers/operations/count
target_path = f"/api/v1/tenants/{tenant_id}/suppliers/operations/count"
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
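# A minimal sketch of the path convention described in the RouteBuilder comments above.
# This is a hypothetical illustration only (the real shared RouteBuilder may differ):
# class RouteBuilder:
#     def __init__(self, resource: str):
#         self.resource = resource
#     def build_base_route(self, suffix: str) -> str:
#         return f"/api/v1/tenants/{{tenant_id}}/{self.resource}/{suffix}"
#     def build_operations_route(self, suffix: str) -> str:
#         return f"/api/v1/tenants/{{tenant_id}}/{self.resource}/operations/{suffix}"
# RouteBuilder("suppliers").build_operations_route("count")
# -> "/api/v1/tenants/{tenant_id}/suppliers/operations/count"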
@router.api_route("/{tenant_id}/suppliers/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_suppliers_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant supplier requests to suppliers service (with additional path)"""
target_path = f"/api/v1/tenants/{tenant_id}/suppliers/{path}".rstrip("/")
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
# NOTE: Purchase orders are now accessed via the main procurement route:
# /api/v1/tenants/{tenant_id}/procurement/purchase-orders/*
# Legacy route removed to enforce standardized structure
@router.api_route("/{tenant_id}/deliveries{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_deliveries(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant delivery requests to suppliers service"""
target_path = f"/api/v1/tenants/{tenant_id}/deliveries{path}".rstrip("/")
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
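# Note: "deliveries{path:path}" matches both "/{tenant_id}/deliveries" (path == "")
# and "/{tenant_id}/deliveries/123" (path == "/123"), since the path convertor may be empty.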
# ================================================================
# TENANT-SCOPED LOCATIONS ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/locations", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_locations_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant locations requests to tenant service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/locations"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/{tenant_id}/locations/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_locations_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant locations requests to tenant service (with additional path)"""
target_path = f"/api/v1/tenants/{tenant_id}/locations/{path}".rstrip("/")
return await _proxy_to_tenant_service(request, target_path)
# ================================================================
# TENANT-SCOPED DISTRIBUTION SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/distribution/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_distribution(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant distribution requests to distribution service"""
target_path = f"/api/v1/tenants/{tenant_id}/distribution/{path}".rstrip("/")
return await _proxy_to_distribution_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED RECIPES SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/recipes", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_recipes_base(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant recipes requests to recipes service (base path)"""
target_path = f"/api/v1/tenants/{tenant_id}/recipes"
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/recipes/count", methods=["GET"])
async def proxy_tenant_recipes_count(request: Request, tenant_id: str = Path(...)):
"""Proxy tenant recipes count requests to recipes service"""
target_path = f"/api/v1/tenants/{tenant_id}/recipes/count"
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
@router.api_route("/{tenant_id}/recipes/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_recipes_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant recipes requests to recipes service (with additional path)"""
target_path = f"/api/v1/tenants/{tenant_id}/recipes/{path}".rstrip("/")
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# TENANT-SCOPED POS SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/pos/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_pos(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant POS requests to POS service"""
target_path = f"/api/v1/tenants/{tenant_id}/pos/{path}".rstrip("/")
return await _proxy_to_pos_service(request, target_path, tenant_id=tenant_id)
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_tenant_service(request: Request, target_path: str):
"""Proxy request to tenant service"""
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
async def _proxy_to_sales_service(request: Request, target_path: str):
"""Proxy request to sales service"""
return await _proxy_request(request, target_path, settings.SALES_SERVICE_URL)
async def _proxy_to_external_service(request: Request, target_path: str):
"""Proxy request to external service"""
return await _proxy_request(request, target_path, settings.EXTERNAL_SERVICE_URL)
async def _proxy_to_training_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to training service"""
return await _proxy_request(request, target_path, settings.TRAINING_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_forecasting_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to forecasting service"""
return await _proxy_request(request, target_path, settings.FORECASTING_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_notification_service(request: Request, target_path: str):
"""Proxy request to notification service"""
return await _proxy_request(request, target_path, settings.NOTIFICATION_SERVICE_URL)
async def _proxy_to_inventory_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to inventory service"""
return await _proxy_request(request, target_path, settings.INVENTORY_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_production_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to production service"""
return await _proxy_request(request, target_path, settings.PRODUCTION_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_orders_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to orders service"""
return await _proxy_request(request, target_path, settings.ORDERS_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_suppliers_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to suppliers service"""
return await _proxy_request(request, target_path, settings.SUPPLIERS_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_recipes_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to recipes service"""
return await _proxy_request(request, target_path, settings.RECIPES_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_pos_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to POS service"""
return await _proxy_request(request, target_path, settings.POS_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_procurement_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to procurement service"""
return await _proxy_request(request, target_path, settings.PROCUREMENT_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_alert_processor_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to alert processor service"""
return await _proxy_request(request, target_path, settings.ALERT_PROCESSOR_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_orchestrator_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to orchestrator service"""
return await _proxy_request(request, target_path, settings.ORCHESTRATOR_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_ai_insights_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to AI insights service"""
return await _proxy_request(request, target_path, settings.AI_INSIGHTS_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_to_distribution_service(request: Request, target_path: str, tenant_id: str = None):
"""Proxy request to distribution service"""
return await _proxy_request(request, target_path, settings.DISTRIBUTION_SERVICE_URL, tenant_id=tenant_id)
async def _proxy_request(request: Request, target_path: str, service_url: str, tenant_id: str = None):
"""Generic proxy function with enhanced error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
url = f"{service_url}{target_path}"
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Add tenant ID header if provided (override if needed)
if tenant_id:
headers["x-tenant-id"] = tenant_id
# Debug logging
user_context = getattr(request.state, 'user', None)
if user_context:
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
else:
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
# Get request body if present
body = None
files = None
data = None
if request.method in ["POST", "PUT", "PATCH"]:
content_type = request.headers.get("content-type", "")
logger.info(f"Processing {request.method} request with content-type: {content_type}")
# Handle multipart/form-data (file uploads)
if "multipart/form-data" in content_type:
logger.info("Detected multipart/form-data, parsing form...")
# For multipart/form-data, we need to re-parse and forward as files
form = await request.form()
logger.info(f"Form parsed, found {len(form)} fields: {list(form.keys())}")
# Extract files and form fields separately
files_dict = {}
data_dict = {}
for key, value in form.items():
if hasattr(value, 'file'): # It's a file
# Read file content
file_content = await value.read()
files_dict[key] = (value.filename, file_content, value.content_type)
logger.info(f"Found file field '{key}': filename={value.filename}, size={len(file_content)}, type={value.content_type}")
else: # It's a regular form field
data_dict[key] = value
logger.info(f"Found form field '{key}': value={value}")
files = files_dict if files_dict else None
data = data_dict if data_dict else None
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
# For multipart requests, refresh the header set and drop the incoming content-type:
# httpx generates a new multipart body with its own boundary, and an explicitly
# forwarded content-type (carrying the old boundary) would take precedence and
# break multipart parsing on the downstream service.
headers = header_manager.get_all_headers_for_proxy(request)
headers = {k: v for k, v in headers.items() if k.lower() != "content-type"}
else:
# For other content types, use body as before
body = await request.body()
logger.info(f"Using raw body, size: {len(body)} bytes")
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0, # Connection timeout
read=600.0, # Read timeout: 10 minutes (was 30s)
write=30.0, # Write timeout
pool=30.0 # Pool timeout
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
files=files,
data=data,
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except ValueError:
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=content
)
except Exception as e:
logger.error(f"Unexpected error proxying to {service_url}{target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)
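# Illustrative end-to-end flow through this router (service URLs are assumed to come
# from gateway settings; the gateway-side prefix depends on how the router is mounted):
#   1. GET <gateway prefix>/{tenant_id}/recipes/count
#   2. proxy_tenant_recipes_count builds "/api/v1/tenants/{tenant_id}/recipes/count"
#   3. _proxy_to_recipes_service -> _proxy_request issues the request against
#      settings.RECIPES_SERVICE_URL and returns the downstream status code and JSON body.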

205
gateway/app/routes/user.py Normal file
View File

@@ -0,0 +1,205 @@
# ================================================================
# gateway/app/routes/user.py
# ================================================================
"""
Authentication routes for API Gateway
"""
import logging
import httpx
from fastapi import APIRouter, Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from typing import Dict, Any
import json
from app.core.config import settings
from app.core.header_manager import header_manager
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
logger = logging.getLogger(__name__)
router = APIRouter()
# Initialize service discovery and metrics
service_discovery = ServiceDiscovery()
metrics = MetricsCollector("gateway")
# Auth service configuration
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"
class UserProxy:
"""Authentication service proxy with enhanced error handling"""
def __init__(self):
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0),
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
)
async def forward_request(
self,
method: str,
path: str,
request: Request
) -> Response:
"""Forward request to auth service with proper error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
# Get auth service URL (with service discovery if available)
auth_url = await self._get_auth_service_url()
# FIX: Auth service uses /api/v1/auth/ prefix, not /api/v1/users/
target_url = f"{auth_url}/api/v1/auth/{path}"
# Prepare headers for forwarding via the unified HeaderManager,
# passing the request itself so headers injected by middleware are included
headers = self._prepare_headers(request.headers, request)
# Get request body
body = await request.body()
# Forward request
logger.info(f"Forwarding {method} {path} to auth service")
response = await self.client.request(
method=method,
url=target_url,
headers=headers,
content=body,
params=dict(request.query_params)
)
# Record metrics
metrics.increment_counter("gateway_auth_requests_total")
metrics.increment_counter(
"gateway_auth_responses_total",
labels={"status_code": str(response.status_code)}
)
# Prepare response headers
response_headers = self._prepare_response_headers(dict(response.headers))
return Response(
content=response.content,
status_code=response.status_code,
headers=response_headers,
media_type=response.headers.get("content-type")
)
except httpx.TimeoutException:
logger.error(f"Timeout forwarding {method} {path} to auth service")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Authentication service timeout"
)
except httpx.ConnectError:
logger.error(f"Connection error forwarding {method} {path} to auth service")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Authentication service unavailable"
)
except Exception as e:
logger.error(f"Error forwarding {method} {path} to auth service: {e}")
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal gateway error"
)
async def _get_auth_service_url(self) -> str:
"""Get auth service URL with service discovery"""
try:
# Try service discovery first
service_url = await service_discovery.get_service_url("auth-service")
if service_url:
return service_url
except Exception as e:
logger.warning(f"Service discovery failed: {e}")
# Fall back to configured URL
return AUTH_SERVICE_URL
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
"""Prepare headers for forwarding using unified HeaderManager"""
# Use unified HeaderManager to get all headers
if request:
all_headers = header_manager.get_all_headers_for_proxy(request)
else:
# Fallback: convert headers to dict manually
all_headers = {}
if hasattr(headers, '_list'):
for k, v in headers.__dict__.get('_list', []):
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
elif hasattr(headers, 'raw'):
for k, v in headers.raw:
key = k.decode() if isinstance(k, bytes) else k
value = v.decode() if isinstance(v, bytes) else v
all_headers[key] = value
else:
# Headers is already a dict
all_headers = dict(headers)
return all_headers
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Prepare response headers"""
# Remove server-specific headers
filtered_headers = {
k: v for k, v in headers.items()
if k.lower() not in {'server', 'date'}
}
# Add CORS headers if needed
if settings.CORS_ORIGINS:
filtered_headers['Access-Control-Allow-Origin'] = '*'
filtered_headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
filtered_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
return filtered_headers
# Initialize proxy
user_proxy = UserProxy()
# ================================================================
# USER MANAGEMENT ENDPOINTS - Proxied to auth service
# ================================================================
@router.get("/delete/{user_id}/deletion-preview")
async def preview_user_deletion(user_id: str, request: Request):
"""Proxy user deletion preview to auth service"""
return await user_proxy.forward_request("GET", f"delete/{user_id}/deletion-preview", request)
@router.delete("/delete/{user_id}")
async def delete_user(user_id: str, request: Request):
"""Proxy admin user deletion to auth service"""
return await user_proxy.forward_request("DELETE", f"delete/{user_id}", request)
# ================================================================
# CATCH-ALL ROUTE for any other user endpoints
# ================================================================
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_auth_requests(path: str, request: Request):
"""Catch-all proxy for auth requests"""
return await user_proxy.forward_request(request.method, path, request)

View File

@@ -0,0 +1,116 @@
"""
Webhook routes for API Gateway - Handles webhook endpoints
Route Configuration Notes:
- Stripe configures webhook URL as: https://domain.com/api/v1/webhooks/stripe
- Gateway receives /api/v1/webhooks/* routes and proxies to tenant service at /webhooks/*
- Gateway routes use /api/v1 prefix, but tenant service routes use /webhooks/* prefix
"""
from fastapi import APIRouter, Request, Response, HTTPException
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
from app.core.header_manager import header_manager
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# WEBHOOK ENDPOINTS - Direct routing to tenant service
# All routes use /api/v1 prefix for consistency
# ================================================================
# Stripe webhook endpoint
@router.post("/api/v1/webhooks/stripe")
async def proxy_stripe_webhook(request: Request):
"""Proxy Stripe webhook requests to tenant service (path: /webhooks/stripe)"""
logger.info("Received Stripe webhook at /api/v1/webhooks/stripe")
return await _proxy_to_tenant_service(request, "/webhooks/stripe")
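# Illustrative flow (the tenant service URL comes from gateway settings):
#   Stripe ->  POST <gateway>/api/v1/webhooks/stripe
#   Gateway -> POST {TENANT_SERVICE_URL}/webhooks/stripe (body and query params passed through)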
# Generic webhook endpoint
@router.post("/api/v1/webhooks/generic")
async def proxy_generic_webhook(request: Request):
"""Proxy generic webhook requests to tenant service (path: /webhooks/generic)"""
return await _proxy_to_tenant_service(request, "/webhooks/generic")
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_tenant_service(request: Request, target_path: str):
"""Proxy request to tenant service"""
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
async def _proxy_request(request: Request, target_path: str, service_url: str):
"""Generic proxy function with enhanced error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID, Stripe-Signature",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
url = f"{service_url}{target_path}"
# Use unified HeaderManager for consistent header forwarding
headers = header_manager.get_all_headers_for_proxy(request)
# Debug logging
logger.info(f"Forwarding webhook request to {url}")
# Get request body if present
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Add query parameters
params = dict(request.query_params)
timeout_config = httpx.Timeout(
connect=30.0,
read=60.0,
write=30.0,
pool=30.0
)
async with httpx.AsyncClient(timeout=timeout_config) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except ValueError:
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=content
)
except Exception as e:
logger.error(f"Unexpected error proxying webhook request to {service_url}{target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)

View File

@@ -0,0 +1,331 @@
"""
Enhanced Subscription Error Responses
Provides detailed, conversion-optimized error responses when users
hit subscription tier restrictions (HTTP 402 Payment Required).
"""
from typing import List, Dict, Optional, Any
from pydantic import BaseModel
class UpgradeBenefit(BaseModel):
"""A single benefit of upgrading"""
text: str
icon: str # Icon name (e.g., 'zap', 'trending-up', 'shield')
class ROIEstimate(BaseModel):
"""ROI estimate for upgrade"""
monthly_savings_min: int
monthly_savings_max: int
currency: str = ""
payback_period_days: int
class FeatureRestrictionDetail(BaseModel):
"""Detailed error response for feature restrictions"""
error: str = "subscription_tier_insufficient"
code: str = "SUBSCRIPTION_UPGRADE_REQUIRED"
status_code: int = 402
message: str
details: Dict[str, Any]
# Feature-specific upgrade messages
FEATURE_MESSAGES = {
'analytics': {
'title': 'Unlock Advanced Analytics',
'description': 'Get deeper insights into your bakery performance with advanced analytics dashboards.',
'benefits': [
UpgradeBenefit(text='90-day forecast horizon (vs 7 days)', icon='calendar'),
UpgradeBenefit(text='Weather & traffic integration', icon='cloud'),
UpgradeBenefit(text='What-if scenario modeling', icon='trending-up'),
UpgradeBenefit(text='Custom reports & dashboards', icon='bar-chart'),
UpgradeBenefit(text='Profitability analysis by product', icon='dollar-sign')
],
'roi': ROIEstimate(
monthly_savings_min=800,
monthly_savings_max=1200,
payback_period_days=7
)
},
'multi_location': {
'title': 'Scale to Multiple Locations',
'description': 'Manage up to 3 bakery locations with centralized inventory and analytics.',
'benefits': [
UpgradeBenefit(text='Up to 3 locations (vs 1)', icon='map-pin'),
UpgradeBenefit(text='Inventory transfer between locations', icon='arrow-right'),
UpgradeBenefit(text='Location comparison analytics', icon='bar-chart'),
UpgradeBenefit(text='Centralized reporting', icon='file-text'),
UpgradeBenefit(text='500 products (vs 50)', icon='package')
],
'roi': ROIEstimate(
monthly_savings_min=1000,
monthly_savings_max=2000,
payback_period_days=10
)
},
'pos_integration': {
'title': 'Integrate Your POS System',
'description': 'Automatically sync sales data from your point-of-sale system.',
'benefits': [
UpgradeBenefit(text='Automatic sales import', icon='refresh-cw'),
UpgradeBenefit(text='Real-time inventory sync', icon='zap'),
UpgradeBenefit(text='Save 10+ hours/week on data entry', icon='clock'),
UpgradeBenefit(text='Eliminate manual errors', icon='check-circle'),
UpgradeBenefit(text='Faster, more accurate forecasts', icon='trending-up')
],
'roi': ROIEstimate(
monthly_savings_min=600,
monthly_savings_max=1000,
payback_period_days=5
)
},
'advanced_forecasting': {
'title': 'Unlock Advanced AI Forecasting',
'description': 'Get more accurate predictions with weather, traffic, and seasonal patterns.',
'benefits': [
UpgradeBenefit(text='Weather-based demand predictions', icon='cloud'),
UpgradeBenefit(text='Traffic & event impact analysis', icon='activity'),
UpgradeBenefit(text='Seasonal pattern detection', icon='calendar'),
UpgradeBenefit(text='15% more accurate forecasts', icon='target'),
UpgradeBenefit(text='Reduce waste by 7+ percentage points', icon='trending-down')
],
'roi': ROIEstimate(
monthly_savings_min=800,
monthly_savings_max=1500,
payback_period_days=7
)
},
'scenario_modeling': {
'title': 'Plan with What-If Scenarios',
'description': 'Model different business scenarios before making decisions.',
'benefits': [
UpgradeBenefit(text='Test menu changes before launch', icon='beaker'),
UpgradeBenefit(text='Optimize pricing strategies', icon='dollar-sign'),
UpgradeBenefit(text='Plan seasonal inventory', icon='calendar'),
UpgradeBenefit(text='Risk assessment tools', icon='shield'),
UpgradeBenefit(text='Data-driven decision making', icon='trending-up')
],
'roi': ROIEstimate(
monthly_savings_min=500,
monthly_savings_max=1000,
payback_period_days=10
)
},
'api_access': {
'title': 'Integrate with Your Tools',
'description': 'Connect bakery.ai with your existing business systems via API.',
'benefits': [
UpgradeBenefit(text='Full REST API access', icon='code'),
UpgradeBenefit(text='1,000 API calls/hour (vs 100)', icon='zap'),
UpgradeBenefit(text='Webhook support for real-time events', icon='bell'),
UpgradeBenefit(text='Custom integrations', icon='link'),
UpgradeBenefit(text='API documentation & support', icon='book')
],
'roi': None # ROI varies by use case
}
}
def create_upgrade_required_response(
feature: str,
current_tier: str,
required_tier: str = 'professional',
allowed_tiers: Optional[List[str]] = None,
custom_message: Optional[str] = None
) -> FeatureRestrictionDetail:
"""
Create an enhanced 402 error response with upgrade suggestions
Args:
feature: Feature key (e.g., 'analytics', 'multi_location')
current_tier: User's current subscription tier
required_tier: Minimum tier required for this feature
allowed_tiers: List of tiers that have access (defaults to [required_tier, 'enterprise'])
custom_message: Optional custom message (overrides default)
Returns:
FeatureRestrictionDetail with upgrade information
"""
if allowed_tiers is None:
allowed_tiers = [required_tier, 'enterprise'] if required_tier != 'enterprise' else ['enterprise']
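# e.g. required_tier='professional' -> allowed_tiers defaults to ['professional', 'enterprise'];
# required_tier='enterprise' -> allowed_tiers defaults to ['enterprise']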
# Get feature-specific messaging
feature_info = FEATURE_MESSAGES.get(feature, {
'title': f'Upgrade to {required_tier.capitalize()}',
'description': f'This feature requires a {required_tier.capitalize()} subscription.',
'benefits': [],
'roi': None
})
# Build detailed response
message = custom_message or feature_info['title']
details = {
'required_feature': feature,
'minimum_tier': required_tier,
'allowed_tiers': allowed_tiers,
'current_tier': current_tier,
# Upgrade messaging
'title': feature_info['title'],
'description': feature_info['description'],
'benefits': [b.dict() for b in feature_info['benefits']],
# ROI information
'roi_estimate': feature_info['roi'].dict() if feature_info['roi'] else None,
# Call-to-action
'upgrade_url': f'/app/settings/subscription?upgrade={required_tier}&from={current_tier}&feature={feature}',
'preview_url': f'/app/{feature}?demo=true' if feature in ['analytics'] else None,
# Suggested tier
'suggested_tier': required_tier,
'suggested_tier_display': required_tier.capitalize(),
# Additional context
'can_preview': feature in ['analytics'],
'has_free_trial': True,
'trial_days': 0,
# Social proof
'social_proof': get_social_proof_message(required_tier),
# Pricing context
'pricing_context': get_pricing_context(required_tier)
}
return FeatureRestrictionDetail(
message=message,
details=details
)
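# Illustrative call (values are examples only):
# resp = create_upgrade_required_response(
#     feature="analytics", current_tier="starter", required_tier="professional"
# )
# resp.status_code == 402
# resp.details["upgrade_url"].startswith("/app/settings/subscription?upgrade=professional")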
def create_quota_exceeded_response(
metric: str,
current: int,
limit: int,
current_tier: str,
upgrade_tier: str = 'professional',
upgrade_limit: Optional[int] = None,
reset_at: Optional[str] = None
) -> Dict[str, Any]:
"""
Create an enhanced 429 error response for quota limits
Args:
metric: The quota metric (e.g., 'training_jobs', 'forecasts')
current: Current usage
limit: Quota limit
current_tier: User's current subscription tier
upgrade_tier: Suggested upgrade tier
upgrade_limit: Limit in upgraded tier (None = unlimited)
reset_at: When the quota resets (ISO datetime string)
Returns:
Error response with upgrade suggestions
"""
metric_labels = {
'training_jobs': 'Training Jobs',
'forecasts': 'Forecasts',
'api_calls': 'API Calls',
'products': 'Products',
'users': 'Users',
'locations': 'Locations'
}
label = metric_labels.get(metric, metric.replace('_', ' ').title())
return {
'error': 'quota_exceeded',
'code': 'QUOTA_LIMIT_REACHED',
'status_code': 429,
'message': f'Daily quota exceeded for {label.lower()}',
'details': {
'metric': metric,
'label': label,
'current': current,
'limit': limit,
'reset_at': reset_at,
'quota_type': metric,
# Upgrade suggestion
'can_upgrade': True,
'upgrade_tier': upgrade_tier,
'upgrade_limit': upgrade_limit,
'upgrade_benefit': f'{upgrade_limit // limit}x more capacity' if upgrade_limit and limit else 'Unlimited capacity',
# Call-to-action
'upgrade_url': f'/app/settings/subscription?upgrade={upgrade_tier}&from={current_tier}&reason=quota_exceeded&metric={metric}',
# ROI context
'roi_message': get_quota_roi_message(metric, current_tier, upgrade_tier)
}
}
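# Illustrative call (values are examples only):
# body = create_quota_exceeded_response(
#     metric="forecasts", current=10, limit=10,
#     current_tier="starter", upgrade_tier="professional", upgrade_limit=100,
#     reset_at="2026-01-22T00:00:00Z",
# )
# body["status_code"] == 429 and body["details"]["label"] == "Forecasts"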
def get_social_proof_message(tier: str) -> str:
"""Get social proof message for a tier"""
messages = {
'professional': '87% of growing bakeries choose Professional',
'enterprise': 'Trusted by multi-location bakery chains across Europe'
}
return messages.get(tier, '')
def get_pricing_context(tier: str) -> Dict[str, Any]:
"""Get pricing context for a tier"""
pricing = {
'professional': {
'monthly_price': 149,
'yearly_price': 1490,
'per_day_cost': 4.97,
'currency': '',
'savings_yearly': 596,
'value_message': 'Only €4.97/day for unlimited growth'
},
'enterprise': {
'monthly_price': 499,
'yearly_price': 4990,
'per_day_cost': 16.63,
'currency': '',
'savings_yearly': 1998,
'value_message': 'Complete solution for €16.63/day'
}
}
return pricing.get(tier, {})
def get_quota_roi_message(metric: str, current_tier: str, upgrade_tier: str) -> str:
"""Get ROI-focused message for quota upgrades"""
messages = {
'training_jobs': 'More training = better predictions = less waste',
'forecasts': 'Run forecasts for all products daily to optimize inventory',
'products': 'Expand your menu without limits',
'users': 'Give your entire team access to real-time data',
'locations': 'Manage all your bakeries from one platform'
}
return messages.get(metric, 'Unlock more capacity to grow your business')
# Example usage function for gateway middleware
def handle_feature_restriction(
feature: str,
current_tier: str,
required_tier: str = 'professional'
) -> tuple[int, Dict[str, Any]]:
"""
Handle feature restriction in gateway middleware
Returns:
(status_code, response_body)
"""
response = create_upgrade_required_response(
feature=feature,
current_tier=current_tier,
required_tier=required_tier
)
return response.status_code, response.dict()
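# Sketch of how gateway middleware might surface this response (assumed wiring,
# not part of this module):
# from fastapi import HTTPException
# status_code, body = handle_feature_restriction("analytics", current_tier="starter")
# raise HTTPException(status_code=status_code, detail=body)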