Initial commit - production deployment
This commit is contained in:
0
gateway/app/__init__.py
Normal file
0
gateway/app/__init__.py
Normal file
0
gateway/app/core/__init__.py
Normal file
0
gateway/app/core/__init__.py
Normal file
46
gateway/app/core/config.py
Normal file
46
gateway/app/core/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# ================================================================
|
||||
# GATEWAY SERVICE CONFIGURATION
|
||||
# gateway/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Gateway service configuration
|
||||
Central API Gateway for all microservices
|
||||
"""
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
class GatewaySettings(BaseServiceSettings):
|
||||
"""Gateway-specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Bakery Forecasting Gateway"
|
||||
SERVICE_NAME: str = "gateway"
|
||||
DESCRIPTION: str = "API Gateway for Bakery Forecasting Platform"
|
||||
|
||||
# Gateway-specific Redis database
|
||||
REDIS_DB: int = 6
|
||||
|
||||
# Service Discovery
|
||||
CONSUL_URL: str = os.getenv("CONSUL_URL", "http://consul:8500")
|
||||
ENABLE_SERVICE_DISCOVERY: bool = os.getenv("ENABLE_SERVICE_DISCOVERY", "false").lower() == "true"
|
||||
|
||||
# Load Balancing
|
||||
ENABLE_LOAD_BALANCING: bool = os.getenv("ENABLE_LOAD_BALANCING", "true").lower() == "true"
|
||||
LOAD_BALANCER_ALGORITHM: str = os.getenv("LOAD_BALANCER_ALGORITHM", "round_robin")
|
||||
|
||||
# Circuit Breaker
|
||||
CIRCUIT_BREAKER_ENABLED: bool = os.getenv("CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
|
||||
CIRCUIT_BREAKER_FAILURE_THRESHOLD: int = int(os.getenv("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
|
||||
CIRCUIT_BREAKER_RECOVERY_TIMEOUT: int = int(os.getenv("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
|
||||
|
||||
# Request/Response Settings
|
||||
MAX_REQUEST_SIZE: int = int(os.getenv("MAX_REQUEST_SIZE", "10485760")) # 10MB
|
||||
REQUEST_TIMEOUT: int = int(os.getenv("REQUEST_TIMEOUT", "30"))
|
||||
|
||||
# Gateway doesn't need a database
|
||||
DATABASE_URL: str = ""
|
||||
|
||||
settings = GatewaySettings()
|
||||
346
gateway/app/core/header_manager.py
Normal file
346
gateway/app/core/header_manager.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Unified Header Management System for API Gateway
|
||||
Centralized header injection, forwarding, and validation
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HeaderManager:
|
||||
"""
|
||||
Centralized header management for consistent header handling across gateway
|
||||
"""
|
||||
|
||||
# Standard header names (lowercase for consistency)
|
||||
STANDARD_HEADERS = {
|
||||
'user_id': 'x-user-id',
|
||||
'user_email': 'x-user-email',
|
||||
'user_role': 'x-user-role',
|
||||
'user_type': 'x-user-type',
|
||||
'service_name': 'x-service-name',
|
||||
'tenant_id': 'x-tenant-id',
|
||||
'subscription_tier': 'x-subscription-tier',
|
||||
'subscription_status': 'x-subscription-status',
|
||||
'is_demo': 'x-is-demo',
|
||||
'demo_session_id': 'x-demo-session-id',
|
||||
'demo_account_type': 'x-demo-account-type',
|
||||
'tenant_access_type': 'x-tenant-access-type',
|
||||
'can_view_children': 'x-can-view-children',
|
||||
'parent_tenant_id': 'x-parent-tenant-id',
|
||||
'forwarded_by': 'x-forwarded-by',
|
||||
'request_id': 'x-request-id'
|
||||
}
|
||||
|
||||
# Headers that should be sanitized/removed from incoming requests
|
||||
SANITIZED_HEADERS = [
|
||||
'x-subscription-',
|
||||
'x-user-',
|
||||
'x-tenant-',
|
||||
'x-demo-',
|
||||
'x-forwarded-by'
|
||||
]
|
||||
|
||||
# Headers that should be forwarded to downstream services
|
||||
FORWARDABLE_HEADERS = [
|
||||
'authorization',
|
||||
'content-type',
|
||||
'accept',
|
||||
'accept-language',
|
||||
'user-agent',
|
||||
'x-internal-service', # Required for internal service-to-service ML/alert triggers
|
||||
'stripe-signature' # Required for Stripe webhook signature verification
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize header manager"""
|
||||
if not self._initialized:
|
||||
logger.info("HeaderManager initialized")
|
||||
self._initialized = True
|
||||
|
||||
def sanitize_incoming_headers(self, request: Request) -> None:
|
||||
"""
|
||||
Remove sensitive headers from incoming request to prevent spoofing
|
||||
"""
|
||||
if not hasattr(request.headers, '_list'):
|
||||
return
|
||||
|
||||
# Filter out headers that start with sanitized prefixes
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not any(k.decode().lower().startswith(prefix.lower())
|
||||
for prefix in self.SANITIZED_HEADERS)
|
||||
]
|
||||
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
logger.debug("Sanitized incoming headers")
|
||||
|
||||
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
|
||||
tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Inject standardized context headers into request
|
||||
Returns dict of injected headers for reference
|
||||
"""
|
||||
injected_headers = {}
|
||||
|
||||
# Ensure headers list exists
|
||||
if not hasattr(request.headers, '_list'):
|
||||
request.headers.__dict__["_list"] = []
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# User context headers
|
||||
if user_context.get('user_id'):
|
||||
header_name = self.STANDARD_HEADERS['user_id']
|
||||
header_value = str(user_context['user_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('email'):
|
||||
header_name = self.STANDARD_HEADERS['user_email']
|
||||
header_value = str(user_context['email'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('role'):
|
||||
header_name = self.STANDARD_HEADERS['user_role']
|
||||
header_value = str(user_context['role'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# User type (service vs regular user)
|
||||
if user_context.get('type'):
|
||||
header_name = self.STANDARD_HEADERS['user_type']
|
||||
header_value = str(user_context['type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Service name for service tokens
|
||||
if user_context.get('service'):
|
||||
header_name = self.STANDARD_HEADERS['service_name']
|
||||
header_value = str(user_context['service'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Tenant context
|
||||
if tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['tenant_id']
|
||||
header_value = str(tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Subscription context
|
||||
if user_context.get('subscription_tier'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_tier']
|
||||
header_value = str(user_context['subscription_tier'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('subscription_status'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_status']
|
||||
header_value = str(user_context['subscription_status'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Demo session context
|
||||
is_demo = user_context.get('is_demo', False)
|
||||
if is_demo:
|
||||
header_name = self.STANDARD_HEADERS['is_demo']
|
||||
header_value = "true"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_session_id'):
|
||||
header_name = self.STANDARD_HEADERS['demo_session_id']
|
||||
header_value = str(user_context['demo_session_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_account_type'):
|
||||
header_name = self.STANDARD_HEADERS['demo_account_type']
|
||||
header_value = str(user_context['demo_account_type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Hierarchical access context
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
header_name = self.STANDARD_HEADERS['tenant_access_type']
|
||||
header_value = str(tenant_access_type)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
header_name = self.STANDARD_HEADERS['can_view_children']
|
||||
header_value = str(can_view_children).lower()
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Parent tenant ID if hierarchical access
|
||||
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
|
||||
if parent_tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Gateway identification
|
||||
header_name = self.STANDARD_HEADERS['forwarded_by']
|
||||
header_value = "bakery-gateway"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Request ID if available
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
if request_id:
|
||||
header_name = self.STANDARD_HEADERS['request_id']
|
||||
header_value = str(request_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
logger.info("🔧 Injected context headers",
|
||||
user_id=user_context.get('user_id'),
|
||||
user_type=user_context.get('type', ''),
|
||||
service_name=user_context.get('service', ''),
|
||||
role=user_context.get('role', ''),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=is_demo,
|
||||
demo_session_id=user_context.get('demo_session_id', ''),
|
||||
path=request.url.path)
|
||||
|
||||
return injected_headers
|
||||
|
||||
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Safely add header to request
|
||||
"""
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name}: {e}")
|
||||
|
||||
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
|
||||
"""
|
||||
Get headers that should be forwarded to downstream services
|
||||
Includes both original request headers and injected context headers
|
||||
"""
|
||||
forwardable_headers = {}
|
||||
|
||||
# Add forwardable original headers
|
||||
for header_name in self.FORWARDABLE_HEADERS:
|
||||
header_value = request.headers.get(header_name)
|
||||
if header_value:
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add injected context headers from request.state
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
for header_name, header_value in request.state.injected_headers.items():
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add authorization header if present
|
||||
auth_header = request.headers.get('authorization')
|
||||
if auth_header:
|
||||
forwardable_headers['authorization'] = auth_header
|
||||
|
||||
return forwardable_headers
|
||||
|
||||
def get_all_headers_for_proxy(self, request: Request,
|
||||
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get complete set of headers for proxying to downstream services
|
||||
"""
|
||||
headers = self.get_forwardable_headers(request)
|
||||
|
||||
# Add any additional headers
|
||||
if additional_headers:
|
||||
headers.update(additional_headers)
|
||||
|
||||
# Remove host header as it will be set by httpx
|
||||
headers.pop('host', None)
|
||||
|
||||
return headers
|
||||
|
||||
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
|
||||
"""
|
||||
Validate that required headers are present
|
||||
"""
|
||||
missing_headers = []
|
||||
|
||||
for header_name in required_headers:
|
||||
# Check in injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
continue
|
||||
|
||||
# Check in request headers
|
||||
if request.headers.get(header_name):
|
||||
continue
|
||||
|
||||
missing_headers.append(header_name)
|
||||
|
||||
if missing_headers:
|
||||
logger.warning(f"Missing required headers: {missing_headers}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_header_value(self, request: Request, header_name: str,
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Get header value from either injected headers or request headers
|
||||
"""
|
||||
# Check injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
return request.state.injected_headers[header_name]
|
||||
|
||||
# Check request headers
|
||||
return request.headers.get(header_name, default)
|
||||
|
||||
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Allow middleware to add headers to the unified header system
|
||||
This ensures all headers are available for proxying
|
||||
"""
|
||||
# Ensure injected_headers exists
|
||||
if not hasattr(request.state, 'injected_headers'):
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# Add header to injected_headers
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Also add to actual request headers for compatibility
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
|
||||
|
||||
logger.debug(f"Middleware added header: {header_name} = {header_value}")
|
||||
|
||||
|
||||
# Global instance for easy access
|
||||
header_manager = HeaderManager()
|
||||
65
gateway/app/core/service_discovery.py
Normal file
65
gateway/app/core/service_discovery.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Service discovery for API Gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Optional, Dict
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ServiceDiscovery:
|
||||
"""Service discovery client"""
|
||||
|
||||
def __init__(self):
|
||||
self.consul_url = settings.CONSUL_URL if hasattr(settings, 'CONSUL_URL') else None
|
||||
self.service_cache: Dict[str, str] = {}
|
||||
|
||||
async def get_service_url(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from service discovery"""
|
||||
|
||||
# Return cached URL if available
|
||||
if service_name in self.service_cache:
|
||||
return self.service_cache[service_name]
|
||||
|
||||
# Try Consul if enabled
|
||||
if self.consul_url and getattr(settings, 'ENABLE_SERVICE_DISCOVERY', False):
|
||||
try:
|
||||
url = await self._get_from_consul(service_name)
|
||||
if url:
|
||||
self.service_cache[service_name] = url
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get {service_name} from Consul: {e}")
|
||||
|
||||
# Fall back to environment variables
|
||||
return self._get_from_env(service_name)
|
||||
|
||||
async def _get_from_consul(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from Consul"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.consul_url}/v1/health/service/{service_name}?passing=true"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
services = response.json()
|
||||
if services:
|
||||
service = services[0]
|
||||
address = service['Service']['Address']
|
||||
port = service['Service']['Port']
|
||||
return f"http://{address}:{port}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Consul query failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _get_from_env(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from environment variables"""
|
||||
env_var = f"{service_name.upper().replace('-', '_')}_SERVICE_URL"
|
||||
return getattr(settings, env_var, None)
|
||||
512
gateway/app/main.py
Normal file
512
gateway/app/main.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
API Gateway - Central entry point for all microservices
|
||||
Handles routing, authentication, rate limiting, and cross-cutting concerns
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import structlog
|
||||
import resource
|
||||
import os
|
||||
import time
|
||||
from fastapi import Request, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
|
||||
from shared.service_base import StandardFastAPIService
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.middleware.request_id import RequestIDMiddleware
|
||||
from app.middleware.auth import AuthMiddleware
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.middleware.rate_limiting import APIRateLimitMiddleware
|
||||
from app.middleware.subscription import SubscriptionMiddleware
|
||||
from app.middleware.demo_middleware import DemoMiddleware
|
||||
from app.middleware.read_only_mode import ReadOnlyModeMiddleware
|
||||
from app.routes import auth, tenant, registration, nominatim, subscription, demo, pos, geocoding, poi_context, webhooks, telemetry
|
||||
|
||||
# Initialize logger
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Check file descriptor limits
|
||||
try:
|
||||
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
if soft_limit < 1024:
|
||||
logger.warning(f"Low file descriptor limit detected: {soft_limit}")
|
||||
else:
|
||||
logger.info(f"File descriptor limit: {soft_limit} (sufficient)")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check file descriptor limits: {e}")
|
||||
|
||||
# Global Redis client for SSE streaming
|
||||
redis_client = None
|
||||
|
||||
|
||||
class GatewayService(StandardFastAPIService):
|
||||
"""Gateway Service with standardized monitoring setup"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Initialize HeaderManager early
|
||||
header_manager.initialize()
|
||||
logger.info("HeaderManager initialized")
|
||||
|
||||
# Initialize Redis during service creation so it's available when needed
|
||||
try:
|
||||
# We need to run the async initialization in a sync context
|
||||
import asyncio
|
||||
try:
|
||||
# Check if there's already a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If there is, we'll initialize Redis later in on_startup
|
||||
self.redis_initialized = False
|
||||
self.redis_client = None
|
||||
except RuntimeError:
|
||||
# No event loop running, safe to run the async function
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply() # Allow nested event loops
|
||||
|
||||
async def init_redis():
|
||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||
return await get_redis_client()
|
||||
|
||||
self.redis_client = asyncio.run(init_redis())
|
||||
self.redis_initialized = True
|
||||
logger.info("Connected to Redis for SSE streaming")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis during service creation: {e}")
|
||||
self.redis_initialized = False
|
||||
self.redis_client = None
|
||||
|
||||
async def on_startup(self, app):
|
||||
"""Custom startup logic for Gateway"""
|
||||
global redis_client
|
||||
|
||||
# Initialize Redis if not already done during service creation
|
||||
if not self.redis_initialized:
|
||||
try:
|
||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||
self.redis_client = await get_redis_client()
|
||||
redis_client = self.redis_client # Update global variable
|
||||
self.redis_initialized = True
|
||||
logger.info("Connected to Redis for SSE streaming")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis during startup: {e}")
|
||||
|
||||
# Register custom metrics for gateway-specific operations
|
||||
if self.telemetry_providers and self.telemetry_providers.app_metrics:
|
||||
logger.info("Gateway-specific metrics tracking enabled")
|
||||
|
||||
await super().on_startup(app)
|
||||
|
||||
async def on_shutdown(self, app):
|
||||
"""Custom shutdown logic for Gateway"""
|
||||
await super().on_shutdown(app)
|
||||
|
||||
# Close Redis
|
||||
await close_redis()
|
||||
logger.info("Redis connection closed")
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = GatewayService(
|
||||
service_name="gateway",
|
||||
app_name="Bakery Forecasting API Gateway",
|
||||
description="Central API Gateway for bakery forecasting microservices",
|
||||
version="1.0.0",
|
||||
log_level=getattr(settings, 'LOG_LEVEL', 'INFO'),
|
||||
cors_origins=settings.CORS_ORIGINS_LIST,
|
||||
enable_metrics=True,
|
||||
enable_health_checks=True,
|
||||
enable_tracing=True,
|
||||
enable_cors=True
|
||||
)
|
||||
|
||||
# Create FastAPI app
|
||||
app = service.create_app()
|
||||
|
||||
# Add API rate limiting middleware with Redis client - this needs to be done after app creation
|
||||
# but before other middleware that might depend on it
|
||||
# Wait for Redis to be initialized if not already done
|
||||
if not hasattr(service, 'redis_client') or not service.redis_client:
|
||||
# Wait briefly for Redis initialization to complete
|
||||
import time
|
||||
time.sleep(1)
|
||||
# Check again after allowing time for initialization
|
||||
if hasattr(service, 'redis_client') and service.redis_client:
|
||||
app.add_middleware(APIRateLimitMiddleware, redis_client=service.redis_client)
|
||||
logger.info("API rate limiting middleware enabled")
|
||||
else:
|
||||
logger.warning("Redis client not available for API rate limiting middleware")
|
||||
else:
|
||||
app.add_middleware(APIRateLimitMiddleware, redis_client=service.redis_client)
|
||||
logger.info("API rate limiting middleware enabled")
|
||||
|
||||
# Add gateway-specific middleware (in REVERSE order of execution)
|
||||
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> ReadOnlyModeMiddleware -> SubscriptionMiddleware -> APIRateLimitMiddleware -> RateLimitMiddleware -> LoggingMiddleware
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
app.add_middleware(RateLimitMiddleware, calls_per_minute=300)
|
||||
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
|
||||
app.add_middleware(ReadOnlyModeMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
app.add_middleware(DemoMiddleware)
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
||||
app.include_router(registration.router, prefix="/api/v1", tags=["registration"])
|
||||
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
|
||||
app.include_router(subscription.router, prefix="/api/v1", tags=["subscriptions"])
|
||||
# Notification routes are now handled by tenant router at /api/v1/tenants/{tenant_id}/notifications/*
|
||||
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
|
||||
app.include_router(geocoding.router, prefix="/api/v1/geocoding", tags=["geocoding"])
|
||||
app.include_router(pos.router, prefix="/api/v1/pos", tags=["pos"])
|
||||
app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
|
||||
# Include webhooks at the root level to handle /api/v1/webhooks/*
|
||||
# Webhook routes are defined with full /api/v1/webhooks/* paths for consistency
|
||||
app.include_router(webhooks.router, prefix="", tags=["webhooks"])
|
||||
|
||||
# Include telemetry routes for frontend OpenTelemetry data
|
||||
app.include_router(telemetry.router, prefix="/api/v1", tags=["telemetry"])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# SERVER-SENT EVENTS (SSE) HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
def _get_subscription_channels(tenant_id: str, channel_filters: list) -> list:
|
||||
"""Determine which Redis channels to subscribe to based on filters"""
|
||||
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
|
||||
all_classes = ["alerts", "notifications"]
|
||||
channels = []
|
||||
|
||||
if not channel_filters:
|
||||
# Subscribe to ALL channels (backward compatible)
|
||||
for domain in all_domains:
|
||||
for event_class in all_classes:
|
||||
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
|
||||
channels.append(f"tenant:{tenant_id}:recommendations")
|
||||
channels.append(f"alerts:{tenant_id}") # Legacy
|
||||
return channels
|
||||
|
||||
# Parse filters and expand wildcards
|
||||
for filter_pattern in channel_filters:
|
||||
if filter_pattern == "*.*":
|
||||
for domain in all_domains:
|
||||
for event_class in all_classes:
|
||||
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
|
||||
channels.append(f"tenant:{tenant_id}:recommendations")
|
||||
elif filter_pattern.endswith(".*"):
|
||||
domain = filter_pattern.split(".")[0]
|
||||
for event_class in all_classes:
|
||||
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
|
||||
elif filter_pattern.startswith("*."):
|
||||
event_class = filter_pattern.split(".")[1]
|
||||
if event_class == "recommendations":
|
||||
channels.append(f"tenant:{tenant_id}:recommendations")
|
||||
else:
|
||||
for domain in all_domains:
|
||||
channels.append(f"tenant:{tenant_id}:{domain}.{event_class}")
|
||||
elif filter_pattern == "recommendations":
|
||||
channels.append(f"tenant:{tenant_id}:recommendations")
|
||||
else:
|
||||
channels.append(f"tenant:{tenant_id}:{filter_pattern}")
|
||||
|
||||
return list(set(channels))
|
||||
|
||||
|
||||
async def _load_initial_state(redis_client, tenant_id: str, channel_filters: list) -> list:
|
||||
"""Load initial state from Redis cache based on channel filters"""
|
||||
initial_events = []
|
||||
|
||||
try:
|
||||
if not channel_filters:
|
||||
# Legacy cache
|
||||
legacy_cache_key = f"active_alerts:{tenant_id}"
|
||||
cached_data = await redis_client.get(legacy_cache_key)
|
||||
if cached_data:
|
||||
return json.loads(cached_data)
|
||||
|
||||
# New domain-specific caches
|
||||
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
|
||||
all_classes = ["alerts", "notifications"]
|
||||
|
||||
for domain in all_domains:
|
||||
for event_class in all_classes:
|
||||
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
|
||||
# Recommendations
|
||||
recommendations_cache_key = f"active_events:{tenant_id}:recommendations"
|
||||
cached_data = await redis_client.get(recommendations_cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
|
||||
return initial_events
|
||||
|
||||
# Load based on specific filters
|
||||
for filter_pattern in channel_filters:
|
||||
if "." in filter_pattern:
|
||||
parts = filter_pattern.split(".")
|
||||
domain = parts[0] if parts[0] != "*" else None
|
||||
event_class = parts[1] if len(parts) > 1 and parts[1] != "*" else None
|
||||
|
||||
if domain and event_class:
|
||||
cache_key = f"active_events:{tenant_id}:{domain}.{event_class}s"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
elif domain and not event_class:
|
||||
for ec in ["alerts", "notifications"]:
|
||||
cache_key = f"active_events:{tenant_id}:{domain}.{ec}"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
elif not domain and event_class:
|
||||
all_domains = ["inventory", "production", "supply_chain", "demand", "operations"]
|
||||
for d in all_domains:
|
||||
cache_key = f"active_events:{tenant_id}:{d}.{event_class}s"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
elif filter_pattern == "recommendations":
|
||||
cache_key = f"active_events:{tenant_id}:recommendations"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
initial_events.extend(json.loads(cached_data))
|
||||
|
||||
return initial_events
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading initial state for tenant {tenant_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _determine_event_type(event_data: dict) -> str:
|
||||
"""Determine SSE event type from event data"""
|
||||
if 'event_class' in event_data:
|
||||
return event_data['event_class']
|
||||
if 'item_type' in event_data:
|
||||
if event_data['item_type'] == 'recommendation':
|
||||
return 'recommendation'
|
||||
else:
|
||||
return 'alert'
|
||||
return 'alert'
|
||||
|
||||
|
||||
# ================================================================
|
||||
# SERVER-SENT EVENTS (SSE) ENDPOINT
|
||||
# ================================================================
|
||||
|
||||
@app.get("/api/v1/events")
|
||||
async def events_stream(
|
||||
request: Request,
|
||||
tenant_id: str,
|
||||
channels: str = None
|
||||
):
|
||||
"""
|
||||
Server-Sent Events stream for real-time notifications with multi-channel support.
|
||||
|
||||
Query Parameters:
|
||||
tenant_id: Tenant identifier (required)
|
||||
channels: Comma-separated channel filters (optional)
|
||||
"""
|
||||
global redis_client
|
||||
|
||||
if not redis_client:
|
||||
raise HTTPException(status_code=503, detail="SSE service unavailable")
|
||||
|
||||
# Extract user context from request state
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id')
|
||||
email = user_context.get('email')
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
|
||||
|
||||
# Parse channel filters
|
||||
channel_filters = []
|
||||
if channels:
|
||||
channel_filters = [c.strip() for c in channels.split(',') if c.strip()]
|
||||
|
||||
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}, channels: {channel_filters or 'all'}")
|
||||
|
||||
async def event_generator():
|
||||
"""Generate server-sent events from Redis pub/sub"""
|
||||
pubsub = None
|
||||
try:
|
||||
pubsub = redis_client.pubsub()
|
||||
logger.debug(f"Created Redis pubsub connection for tenant: {tenant_id}")
|
||||
|
||||
# Determine channels
|
||||
subscription_channels = _get_subscription_channels(tenant_id, channel_filters)
|
||||
|
||||
# Subscribe
|
||||
if subscription_channels:
|
||||
await pubsub.subscribe(*subscription_channels)
|
||||
logger.info(f"Subscribed to {len(subscription_channels)} channels for tenant {tenant_id}")
|
||||
else:
|
||||
legacy_channel = f"alerts:{tenant_id}"
|
||||
await pubsub.subscribe(legacy_channel)
|
||||
|
||||
# Connection event
|
||||
yield f"event: connection\n"
|
||||
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'channels': subscription_channels or ['all'], 'timestamp': time.time()})}\n\n"
|
||||
|
||||
# Initial state
|
||||
initial_events = await _load_initial_state(redis_client, tenant_id, channel_filters)
|
||||
if initial_events:
|
||||
logger.info(f"Sending {len(initial_events)} initial events to tenant {tenant_id}")
|
||||
yield f"event: initial_state\n"
|
||||
yield f"data: {json.dumps(initial_events)}\n\n"
|
||||
|
||||
heartbeat_counter = 0
|
||||
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
logger.info(f"SSE client disconnected for tenant: {tenant_id}")
|
||||
break
|
||||
|
||||
try:
|
||||
message = await asyncio.wait_for(pubsub.get_message(ignore_subscribe_messages=True), timeout=10.0)
|
||||
|
||||
if message and message['type'] == 'message':
|
||||
event_data = json.loads(message['data'])
|
||||
event_type = _determine_event_type(event_data)
|
||||
event_data['_channel'] = message['channel'].decode('utf-8') if isinstance(message['channel'], bytes) else message['channel']
|
||||
|
||||
yield f"event: {event_type}\n"
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
heartbeat_counter += 1
|
||||
if heartbeat_counter >= 10:
|
||||
yield f"event: heartbeat\n"
|
||||
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
|
||||
heartbeat_counter = 0
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE connection cancelled for tenant: {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"SSE error for tenant {tenant_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
if pubsub:
|
||||
try:
|
||||
await pubsub.unsubscribe()
|
||||
await pubsub.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing pubsub: {e}")
|
||||
logger.info(f"SSE connection closed for tenant: {tenant_id}")
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers": "Cache-Control",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# WEBSOCKET ROUTING FOR TRAINING SERVICE
|
||||
# ================================================================
|
||||
|
||||
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
||||
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
|
||||
"""WebSocket proxy with token verification for training service"""
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Authentication token required")
|
||||
return
|
||||
|
||||
# Verify token
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload or not payload.get('user_id'):
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
return
|
||||
except Exception as e:
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Token verification failed")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Build WebSocket URL to training service
|
||||
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
|
||||
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
|
||||
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
|
||||
|
||||
training_ws = None
|
||||
|
||||
try:
|
||||
import websockets
|
||||
from websockets.protocol import State
|
||||
|
||||
training_ws = await websockets.connect(
|
||||
training_ws_url,
|
||||
ping_interval=120,
|
||||
ping_timeout=60,
|
||||
close_timeout=60,
|
||||
open_timeout=30
|
||||
)
|
||||
|
||||
async def forward_frontend_to_training():
|
||||
try:
|
||||
while training_ws and training_ws.state == State.OPEN:
|
||||
data = await websocket.receive()
|
||||
if data.get("type") == "websocket.receive":
|
||||
if "text" in data:
|
||||
await training_ws.send(data["text"])
|
||||
elif "bytes" in data:
|
||||
await training_ws.send(data["bytes"])
|
||||
elif data.get("type") == "websocket.disconnect":
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def forward_training_to_frontend():
|
||||
try:
|
||||
while training_ws and training_ws.state == State.OPEN:
|
||||
message = await training_ws.recv()
|
||||
await websocket.send_text(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.gather(
|
||||
forward_frontend_to_training(),
|
||||
forward_training_to_frontend(),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
|
||||
finally:
|
||||
if training_ws and training_ws.state == State.OPEN:
|
||||
try:
|
||||
await training_ws.close()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
if not websocket.client_state.name == 'DISCONNECTED':
|
||||
await websocket.close(code=1000, reason="Proxy closed")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
0
gateway/app/middleware/__init__.py
Normal file
0
gateway/app/middleware/__init__.py
Normal file
649
gateway/app/middleware/auth.py
Normal file
649
gateway/app/middleware/auth.py
Normal file
@@ -0,0 +1,649 @@
|
||||
# gateway/app/middleware/auth.py
|
||||
"""
|
||||
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
|
||||
FIXED VERSION - Proper JWT verification and token structure handling
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# JWT handler for local token validation - using SAME configuration as auth service
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
# Routes that don't require authentication
|
||||
PUBLIC_ROUTES = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/api/v1/auth/verify",
|
||||
"/api/v1/auth/start-registration", # Registration step 1 - SetupIntent creation
|
||||
"/api/v1/auth/complete-registration", # Registration step 2 - Completion after 3DS
|
||||
"/api/v1/registration/payment-setup", # New registration payment setup endpoint
|
||||
"/api/v1/registration/complete", # New registration completion endpoint
|
||||
"/api/v1/registration/state/", # Registration state check
|
||||
"/api/v1/auth/verify-email", # Email verification
|
||||
"/api/v1/auth/password/reset-request", # Password reset request - no auth required
|
||||
"/api/v1/auth/password/reset", # Password reset with token - no auth required
|
||||
"/api/v1/nominatim/search",
|
||||
"/api/v1/plans",
|
||||
"/api/v1/demo/accounts",
|
||||
"/api/v1/demo/sessions",
|
||||
"/api/v1/webhooks/stripe", # Stripe webhook endpoint - bypasses auth for signature verification
|
||||
"/api/v1/webhooks/generic", # Generic webhook endpoint
|
||||
"/api/v1/telemetry/v1/traces", # Frontend telemetry traces - no auth for performance
|
||||
"/api/v1/telemetry/v1/metrics", # Frontend telemetry metrics - no auth for performance
|
||||
"/api/v1/telemetry/health" # Telemetry health check
|
||||
]
|
||||
|
||||
# Routes accessible with demo session (no JWT required, just demo session header)
|
||||
DEMO_ACCESSIBLE_ROUTES = [
|
||||
"/api/v1/tenants/", # All tenant endpoints accessible in demo mode
|
||||
]
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Enhanced Authentication Middleware with Tenant Access Control
|
||||
"""
|
||||
|
||||
def __init__(self, app, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client # For caching and rate limiting
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with enhanced authentication and tenant access control"""
|
||||
|
||||
# Skip authentication for OPTIONS requests (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# SECURITY: Remove any incoming sensitive headers using HeaderManager
|
||||
header_manager.sanitize_incoming_headers(request)
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# ✅ Check if demo middleware already set user context OR check query param for SSE
|
||||
demo_session_header = request.headers.get("X-Demo-Session-Id")
|
||||
demo_session_query = request.query_params.get("demo_session_id") # For SSE endpoint
|
||||
logger.info(f"Auth check - path: {request.url.path}, demo_header: {demo_session_header}, demo_query: {demo_session_query}, has_demo_state: {hasattr(request.state, 'is_demo_session')}")
|
||||
|
||||
# For SSE endpoint with demo_session_id in query params, validate it here
|
||||
if request.url.path == "/api/v1/events" and demo_session_query and not hasattr(request.state, "is_demo_session"):
|
||||
logger.info(f"SSE endpoint with demo_session_id query param: {demo_session_query}")
|
||||
# Validate demo session via demo-session service using JWT service token
|
||||
import httpx
|
||||
try:
|
||||
# Create service token for gateway-to-demo-session communication
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"http://demo-session-service:8000/api/v1/demo/sessions/{demo_session_query}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
session_data = response.json()
|
||||
# Set demo session context
|
||||
request.state.is_demo_session = True
|
||||
request.state.user = {
|
||||
"user_id": f"demo-user-{demo_session_query}",
|
||||
"email": f"demo-{demo_session_query}@demo.local",
|
||||
"tenant_id": session_data.get("virtual_tenant_id"),
|
||||
"demo_session_id": demo_session_query,
|
||||
}
|
||||
request.state.tenant_id = session_data.get("virtual_tenant_id")
|
||||
logger.info(f"✅ Demo session validated for SSE: {demo_session_query}")
|
||||
else:
|
||||
logger.warning(f"Invalid demo session for SSE: {demo_session_query}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid demo session"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate demo session for SSE: {e}")
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={"detail": "Demo session service unavailable"}
|
||||
)
|
||||
|
||||
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
logger.info(f"✅ Demo session authenticated for route: {request.url.path}")
|
||||
# Demo middleware already validated and set user context
|
||||
# But we still need to inject context headers for downstream services
|
||||
user_context = request.state.user
|
||||
tenant_id = user_context.get("tenant_id") or getattr(request.state, "tenant_id", None)
|
||||
|
||||
# For demo sessions, get the actual subscription tier from the tenant service
|
||||
# instead of always defaulting to enterprise
|
||||
if not user_context.get("subscription_tier"):
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
else:
|
||||
# Fallback to enterprise for demo if no tier is found
|
||||
user_context["subscription_tier"] = "enterprise"
|
||||
|
||||
logger.debug(f"Demo session subscription tier set to {user_context['subscription_tier']}", tenant_id=tenant_id)
|
||||
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
# ✅ STEP 1: Extract and validate JWT token
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
logger.warning(f"❌ Missing token for protected route: {request.url.path}, demo_header: {demo_session_header}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# ✅ STEP 2: Verify token and get user context
|
||||
user_context = await self._verify_token(token, request)
|
||||
if not user_context:
|
||||
logger.warning(f"Invalid token for route: {request.url.path}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "User not authenticated"}
|
||||
)
|
||||
|
||||
# ✅ STEP 3: Extract tenant context from URL using shared utility
|
||||
tenant_id = extract_tenant_id_from_path(request.url.path)
|
||||
|
||||
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
|
||||
if tenant_id and is_tenant_scoped_path(request.url.path):
|
||||
# Skip tenant access verification for service tokens (services have admin access)
|
||||
if user_context.get("type") != "service":
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": f"Access denied to tenant {tenant_id}"}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Service token granted access to tenant {tenant_id}",
|
||||
service=user_context.get("service"))
|
||||
|
||||
# Get tenant subscription tier and inject into user context
|
||||
# NEW: Use JWT data if available, skip HTTP call
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
subscription_tier = user_context.get("subscription_tier")
|
||||
logger.debug("Using subscription tier from JWT", tier=subscription_tier)
|
||||
else:
|
||||
# Only for old tokens - remove after full rollout
|
||||
subscription_tier = await self._get_tenant_subscription_tier(tenant_id, request)
|
||||
|
||||
if subscription_tier:
|
||||
user_context["subscription_tier"] = subscription_tier
|
||||
|
||||
# Check hierarchical access to determine access type and permissions
|
||||
hierarchical_access = await tenant_access_manager.verify_hierarchical_access(
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
# Set tenant context in request state
|
||||
request.state.tenant_id = tenant_id
|
||||
request.state.tenant_verified = True
|
||||
request.state.tenant_access_type = hierarchical_access.get("access_type", "direct")
|
||||
request.state.can_view_children = hierarchical_access.get("can_view_children", False)
|
||||
|
||||
logger.debug(f"Tenant access verified",
|
||||
user_id=user_context["user_id"],
|
||||
tenant_id=tenant_id,
|
||||
subscription_tier=subscription_tier,
|
||||
access_type=hierarchical_access.get("access_type"),
|
||||
can_view_children=hierarchical_access.get("can_view_children"),
|
||||
path=request.url.path)
|
||||
|
||||
# ✅ STEP 5: Inject user context into request
|
||||
request.state.user = user_context
|
||||
request.state.authenticated = True
|
||||
|
||||
# ✅ STEP 6: Add context headers for downstream services
|
||||
await self._inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
logger.debug(f"Authenticated request",
|
||||
user_email=user_context['email'],
|
||||
tenant_id=tenant_id,
|
||||
path=request.url.path)
|
||||
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add token expiry warning header if token is near expiry
|
||||
if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
|
||||
response.headers["X-Token-Refresh-Suggested"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
"""Check if route requires authentication"""
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract JWT token from Authorization header or query params for SSE.
|
||||
|
||||
For SSE endpoints (/api/v1/events), browsers' EventSource API cannot send
|
||||
custom headers, so we must accept token as query parameter.
|
||||
For all other routes, token must be in Authorization header (more secure).
|
||||
|
||||
Security note: Query param tokens are logged. Use short expiry and filter logs.
|
||||
"""
|
||||
# SSE endpoint exception: token in query param (EventSource API limitation)
|
||||
if request.url.path == "/api/v1/events":
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
logger.debug("Token extracted from query param for SSE endpoint")
|
||||
return token
|
||||
logger.warning("SSE request missing token in query param")
|
||||
return None
|
||||
|
||||
# Standard authentication: Authorization header for all other routes
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token with improved fallback strategy
|
||||
FIXED: Better error handling and token structure validation
|
||||
"""
|
||||
|
||||
# Strategy 1: Try local JWT validation first (fast)
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload and self._validate_token_payload(payload):
|
||||
logger.debug("Token validated locally")
|
||||
|
||||
# NEW: Check token freshness for subscription changes (async)
|
||||
if payload.get("tenant_id") and request:
|
||||
try:
|
||||
is_fresh = await self._verify_token_freshness(payload, payload["tenant_id"])
|
||||
if not is_fresh:
|
||||
logger.warning("Stale token detected - subscription changed since token was issued",
|
||||
user_id=payload.get("user_id"),
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is stale - subscription has changed"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check failed, allowing token", error=str(e))
|
||||
# Allow token if check fails (fail open for availability)
|
||||
|
||||
# Check if token is near expiry and set flag for response header
|
||||
if request:
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
request.state.token_near_expiry = True
|
||||
|
||||
# Convert JWT payload to user context format
|
||||
return self._jwt_payload_to_user_context(payload)
|
||||
except Exception as e:
|
||||
logger.debug(f"Local token validation failed: {e}")
|
||||
|
||||
# Strategy 2: Check cache for recently validated tokens
|
||||
if self.redis_client:
|
||||
try:
|
||||
cached_user = await self._get_cached_user(token)
|
||||
if cached_user:
|
||||
logger.debug("Token found in cache")
|
||||
return cached_user
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
|
||||
# Strategy 3: Verify with auth service (authoritative)
|
||||
try:
|
||||
user_context = await self._verify_with_auth_service(token)
|
||||
if user_context:
|
||||
# Cache successful validation
|
||||
if self.redis_client:
|
||||
await self._cache_user(token, user_context)
|
||||
logger.debug("Token validated by auth service")
|
||||
return user_context
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service validation failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate JWT payload has required fields
|
||||
FIXED: Updated to match actual token structure from auth service
|
||||
"""
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
missing_fields = [field for field in required_fields if field not in payload]
|
||||
|
||||
if missing_fields:
|
||||
logger.warning(f"Token payload missing fields: {missing_fields}")
|
||||
return False
|
||||
|
||||
# Validate token type
|
||||
token_type = payload.get("type")
|
||||
if token_type not in ["access", "service"]:
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return False
|
||||
|
||||
# Check if token is near expiry (within 5 minutes) and log warning
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
|
||||
|
||||
# NEW: Check token freshness for subscription changes
|
||||
if payload.get("tenant_id"):
|
||||
try:
|
||||
# Note: We can't await here because this is a sync function
|
||||
# Token freshness will be checked in the async dispatch method
|
||||
# For now, just log that we would check freshness
|
||||
logger.debug("Token freshness check would be performed in async context",
|
||||
tenant_id=payload.get("tenant_id"))
|
||||
except Exception as e:
|
||||
logger.warning("Token freshness check setup failed", error=str(e))
|
||||
|
||||
# FIX: Validate service tokens with tenant context for tenant-scoped routes
|
||||
if token_type == "service" and payload.get("tenant_id"):
|
||||
# Service tokens with tenant context are valid for tenant-scoped operations
|
||||
logger.debug("Service token with tenant context validated",
|
||||
service=payload.get("service"), tenant_id=payload.get("tenant_id"))
|
||||
|
||||
return True
|
||||
|
||||
def _validate_jwt_integrity(self, payload: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate JWT payload integrity beyond signature verification.
|
||||
Prevents edge cases where payload might be malformed.
|
||||
"""
|
||||
# Required fields must exist
|
||||
required_fields = ["user_id", "email", "exp", "iat", "iss"]
|
||||
if not all(field in payload for field in required_fields):
|
||||
logger.warning("JWT missing required fields", missing=[f for f in required_fields if f not in payload])
|
||||
return False
|
||||
|
||||
# Issuer must be our auth service
|
||||
if payload.get("iss") != "bakery-auth":
|
||||
logger.warning("JWT has invalid issuer", issuer=payload.get("iss"))
|
||||
return False
|
||||
|
||||
# Token type must be valid
|
||||
if payload.get("type") not in ["access", "service"]:
|
||||
logger.warning("JWT has invalid type", token_type=payload.get("type"))
|
||||
return False
|
||||
|
||||
# Subscription tier must be valid if present
|
||||
valid_tiers = ["starter", "professional", "enterprise"]
|
||||
if payload.get("subscription"):
|
||||
tier = payload["subscription"].get("tier", "").lower()
|
||||
if tier and tier not in valid_tiers:
|
||||
logger.warning("JWT has invalid subscription tier", tier=tier)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _verify_token_freshness(self, payload: Dict[str, Any], tenant_id: str) -> bool:
|
||||
"""
|
||||
Verify token was issued after the last subscription change.
|
||||
Prevents use of stale tokens with old subscription data.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return True # Skip check if no Redis
|
||||
|
||||
try:
|
||||
subscription_changed_at = await self.redis_client.get(
|
||||
f"tenant:{tenant_id}:subscription_changed_at"
|
||||
)
|
||||
|
||||
if subscription_changed_at:
|
||||
changed_timestamp = float(subscription_changed_at)
|
||||
token_issued_at = payload.get("iat", 0)
|
||||
|
||||
if token_issued_at < changed_timestamp:
|
||||
logger.warning(
|
||||
"Token issued before subscription change",
|
||||
token_iat=token_issued_at,
|
||||
subscription_changed=changed_timestamp,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return False # Token is stale
|
||||
except Exception as e:
|
||||
logger.warning("Failed to check token freshness", error=str(e))
|
||||
|
||||
return True
|
||||
|
||||
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert JWT payload to user context format
|
||||
FIXED: Proper mapping between JWT structure and user context
|
||||
"""
|
||||
# NEW: Validate JWT integrity before processing
|
||||
if not self._validate_jwt_integrity(payload):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid JWT payload"
|
||||
)
|
||||
|
||||
base_context = {
|
||||
"user_id": payload["user_id"],
|
||||
"email": payload["email"],
|
||||
"exp": payload["exp"],
|
||||
"valid": True,
|
||||
"role": payload.get("role", "user"),
|
||||
}
|
||||
|
||||
# NEW: Extract subscription from JWT
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
base_context["tenant_role"] = payload.get("tenant_role", "member")
|
||||
|
||||
if payload.get("subscription"):
|
||||
sub = payload["subscription"]
|
||||
base_context["subscription_tier"] = sub.get("tier", "starter")
|
||||
base_context["subscription_status"] = sub.get("status", "active")
|
||||
base_context["subscription_from_jwt"] = True # Flag to skip HTTP
|
||||
|
||||
if payload.get("tenant_access"):
|
||||
base_context["tenant_access"] = payload["tenant_access"]
|
||||
|
||||
if payload.get("service"):
|
||||
service_name = payload["service"]
|
||||
base_context["service"] = service_name
|
||||
base_context["type"] = "service"
|
||||
base_context["role"] = "admin"
|
||||
base_context["user_id"] = f"{service_name}-service"
|
||||
base_context["email"] = f"{service_name}-service@internal"
|
||||
|
||||
# FIX: Service tokens with tenant context should use that tenant_id
|
||||
if payload.get("tenant_id"):
|
||||
base_context["tenant_id"] = payload["tenant_id"]
|
||||
logger.debug(f"Service authentication with tenant context: {service_name}, tenant_id: {payload['tenant_id']}")
|
||||
else:
|
||||
logger.debug(f"Service authentication: {service_name}")
|
||||
|
||||
return base_context
|
||||
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify token with auth service
|
||||
FIXED: Improved error handling and response parsing
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
auth_response = response.json()
|
||||
|
||||
# Validate auth service response structure
|
||||
if auth_response.get("valid") and auth_response.get("user_id"):
|
||||
return {
|
||||
"user_id": auth_response["user_id"],
|
||||
"email": auth_response["email"],
|
||||
"exp": auth_response.get("exp"),
|
||||
"valid": True
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Auth service returned invalid response: {auth_response}")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Auth service returned {response.status_code}: {response.text}")
|
||||
return None
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error("Auth service timeout during token verification")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service error: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get user context from cache
|
||||
FIXED: Better error handling and JSON parsing
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}" # Use modulo for shorter keys
|
||||
try:
|
||||
cached_data = await self.redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
if isinstance(cached_data, bytes):
|
||||
cached_data = cached_data.decode()
|
||||
return json.loads(cached_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse cached user data: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cache_user(self, token: str, user_context: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Cache user context
|
||||
FIXED: Better error handling and expiration
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
cache_key = f"auth:token:{hash(token) % 1000000}"
|
||||
try:
|
||||
# Cache for 5 minutes (shorter than token expiry)
|
||||
await self.redis_client.setex(
|
||||
cache_key,
|
||||
300, # 5 minutes
|
||||
json.dumps(user_context)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache user context: {e}")
|
||||
|
||||
async def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str] = None):
|
||||
"""
|
||||
Inject user and tenant context headers for downstream services using unified HeaderManager
|
||||
"""
|
||||
# Use unified HeaderManager for consistent header injection
|
||||
injected_headers = header_manager.inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
# Add hierarchical access headers if tenant context exists
|
||||
if tenant_id:
|
||||
# If this is hierarchical access, include parent tenant ID
|
||||
# Get parent tenant ID from the auth service if available
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/hierarchy",
|
||||
headers={"Authorization": request.headers.get("Authorization", "")}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
hierarchy_data = response.json()
|
||||
parent_tenant_id = hierarchy_data.get("parent_tenant_id")
|
||||
if parent_tenant_id:
|
||||
# Add parent tenant ID using HeaderManager for consistency
|
||||
header_name = header_manager.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
header_manager.add_header_for_middleware(request, header_name, header_value)
|
||||
logger.info(f"Added parent tenant ID header: {parent_tenant_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get parent tenant ID: {e}")
|
||||
pass
|
||||
|
||||
return injected_headers
|
||||
|
||||
async def _get_tenant_subscription_tier(self, tenant_id: str, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Get tenant subscription tier using fast cached endpoint
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
request: FastAPI request for headers
|
||||
|
||||
Returns:
|
||||
Subscription tier string or None
|
||||
"""
|
||||
try:
|
||||
# Use fast cached subscription tier endpoint (has its own Redis caching)
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
headers = {"Authorization": request.headers.get("Authorization", "")}
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
tier_data = response.json()
|
||||
subscription_tier = tier_data.get("tier", "starter")
|
||||
|
||||
logger.debug("Subscription tier from cached endpoint",
|
||||
tenant_id=tenant_id,
|
||||
tier=subscription_tier,
|
||||
cached=tier_data.get("cached", False))
|
||||
return subscription_tier
|
||||
else:
|
||||
logger.warning(f"Failed to get tenant subscription tier: {response.status_code}")
|
||||
return "starter" # Default to starter
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tenant subscription tier: {e}")
|
||||
return "starter" # Default to starter on error
|
||||
384
gateway/app/middleware/demo_middleware.py
Normal file
384
gateway/app/middleware/demo_middleware.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Demo Session Middleware
|
||||
Handles demo account restrictions and virtual tenant injection
|
||||
"""
|
||||
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Optional
|
||||
import uuid
|
||||
import httpx
|
||||
import structlog
|
||||
import json
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Fixed Demo Tenant IDs (these are the template tenants that will be cloned)
|
||||
# Professional demo (merged from San Pablo + La Espiga)
|
||||
DEMO_TENANT_PROFESSIONAL = uuid.UUID("a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6")
|
||||
|
||||
# Enterprise chain demo (parent + 3 children)
|
||||
DEMO_TENANT_ENTERPRISE_CHAIN = uuid.UUID("c3d4e5f6-a7b8-49c0-d1e2-f3a4b5c6d7e8")
|
||||
DEMO_TENANT_CHILD_1 = uuid.UUID("d4e5f6a7-b8c9-40d1-e2f3-a4b5c6d7e8f9")
|
||||
DEMO_TENANT_CHILD_2 = uuid.UUID("e5f6a7b8-c9d0-41e2-f3a4-b5c6d7e8f9a0")
|
||||
DEMO_TENANT_CHILD_3 = uuid.UUID("f6a7b8c9-d0e1-42f3-a4b5-c6d7e8f9a0b1")
|
||||
|
||||
# Demo tenant IDs (base templates)
|
||||
DEMO_TENANT_IDS = {
|
||||
str(DEMO_TENANT_PROFESSIONAL), # Professional demo tenant
|
||||
str(DEMO_TENANT_ENTERPRISE_CHAIN), # Enterprise chain parent
|
||||
str(DEMO_TENANT_CHILD_1), # Enterprise chain child 1
|
||||
str(DEMO_TENANT_CHILD_2), # Enterprise chain child 2
|
||||
str(DEMO_TENANT_CHILD_3), # Enterprise chain child 3
|
||||
}
|
||||
|
||||
# Demo user IDs - Maps demo account type to actual user UUIDs from fixture files
|
||||
# These IDs are the owner IDs from the respective 01-tenant.json files
|
||||
DEMO_USER_IDS = {
|
||||
"professional": "c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6", # María García López (professional/01-tenant.json -> owner.id)
|
||||
"enterprise": "d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7" # Director (enterprise/parent/01-tenant.json -> owner.id)
|
||||
}
|
||||
|
||||
# Allowed operations for demo accounts (limited write)
|
||||
DEMO_ALLOWED_OPERATIONS = {
|
||||
# Read operations - all allowed
|
||||
"GET": ["*"],
|
||||
|
||||
# Limited write operations for realistic testing
|
||||
"POST": [
|
||||
"/api/v1/pos/sales",
|
||||
"/api/v1/pos/sessions",
|
||||
"/api/v1/orders",
|
||||
"/api/v1/inventory/adjustments",
|
||||
"/api/v1/sales",
|
||||
"/api/v1/production/batches",
|
||||
"/api/v1/tenants/batch/sales-summary",
|
||||
"/api/v1/tenants/batch/production-summary",
|
||||
"/api/v1/auth/me/onboarding/complete", # Allow completing onboarding (no-op for demos)
|
||||
"/api/v1/tenants/*/notifications/send", # Allow notifications (ML insights, alerts, etc.)
|
||||
# Note: Forecast generation is explicitly blocked (see DEMO_BLOCKED_PATHS)
|
||||
],
|
||||
|
||||
"PUT": [
|
||||
"/api/v1/pos/sales/*",
|
||||
"/api/v1/orders/*",
|
||||
"/api/v1/inventory/stock/*",
|
||||
"/api/v1/auth/me/onboarding/step", # Allow onboarding step updates (no-op for demos)
|
||||
],
|
||||
|
||||
# Blocked operations
|
||||
"DELETE": [], # No deletes allowed
|
||||
"PATCH": [], # No patches allowed
|
||||
}
|
||||
|
||||
# Explicitly blocked paths for demo accounts (even if method would be allowed)
|
||||
# These require trained AI models which demo tenants don't have
|
||||
DEMO_BLOCKED_PATHS = [
|
||||
"/api/v1/forecasts/single",
|
||||
"/api/v1/forecasts/multi-day",
|
||||
"/api/v1/forecasts/batch",
|
||||
]
|
||||
|
||||
DEMO_BLOCKED_PATH_MESSAGE = {
|
||||
"forecasts": {
|
||||
"message": "La generación de pronósticos no está disponible para cuentas demo. "
|
||||
"Las cuentas demo no tienen modelos de IA entrenados.",
|
||||
"message_en": "Forecast generation is not available for demo accounts. "
|
||||
"Demo accounts do not have trained AI models.",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DemoMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to handle demo session logic with Redis caching"""
|
||||
|
||||
def __init__(self, app, demo_session_url: str = "http://demo-session-service:8000"):
|
||||
super().__init__(app)
|
||||
self.demo_session_url = demo_session_url
|
||||
self._redis_client = None
|
||||
|
||||
async def _get_redis_client(self):
|
||||
"""Get or lazily initialize Redis client"""
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
from shared.redis_utils import get_redis_client
|
||||
self._redis_client = await get_redis_client()
|
||||
logger.debug("Demo middleware: Redis client initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Failed to get Redis client: {e}. Caching disabled.")
|
||||
self._redis_client = False # Sentinel value to avoid retrying
|
||||
|
||||
return self._redis_client if self._redis_client is not False else None
|
||||
|
||||
async def _get_cached_session(self, session_id: str) -> Optional[dict]:
|
||||
"""Get session info from Redis cache"""
|
||||
try:
|
||||
redis_client = await self._get_redis_client()
|
||||
if not redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"demo_session:{session_id}"
|
||||
cached_data = await redis_client.get(cache_key)
|
||||
|
||||
if cached_data:
|
||||
logger.debug("Demo middleware: Cache HIT", session_id=session_id)
|
||||
return json.loads(cached_data)
|
||||
else:
|
||||
logger.debug("Demo middleware: Cache MISS", session_id=session_id)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Redis cache read error: {e}")
|
||||
return None
|
||||
|
||||
async def _cache_session(self, session_id: str, session_info: dict, ttl: int = 30):
|
||||
"""Cache session info in Redis with TTL"""
|
||||
try:
|
||||
redis_client = await self._get_redis_client()
|
||||
if not redis_client:
|
||||
return
|
||||
|
||||
cache_key = f"demo_session:{session_id}"
|
||||
serialized = json.dumps(session_info)
|
||||
await redis_client.setex(cache_key, ttl, serialized)
|
||||
logger.debug(f"Demo middleware: Cached session {session_id} (TTL: {ttl}s)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Demo middleware: Redis cache write error: {e}")
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request through demo middleware"""
|
||||
|
||||
# Skip demo middleware for demo service endpoints
|
||||
demo_service_paths = [
|
||||
"/api/v1/demo/accounts",
|
||||
"/api/v1/demo/sessions",
|
||||
"/api/v1/demo/stats",
|
||||
"/api/v1/demo/operations",
|
||||
]
|
||||
|
||||
if any(request.url.path.startswith(path) or request.url.path == path for path in demo_service_paths):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract session ID from header or cookie
|
||||
session_id = (
|
||||
request.headers.get("X-Demo-Session-Id") or
|
||||
request.cookies.get("demo_session_id")
|
||||
)
|
||||
|
||||
logger.info(f"🎭 DemoMiddleware - path: {request.url.path}, session_id: {session_id}")
|
||||
|
||||
# Extract tenant ID from request
|
||||
tenant_id = request.headers.get("X-Tenant-Id")
|
||||
|
||||
# Check if this is a demo session request
|
||||
if session_id:
|
||||
try:
|
||||
# PERFORMANCE OPTIMIZATION: Check Redis cache first before HTTP call
|
||||
session_info = await self._get_cached_session(session_id)
|
||||
|
||||
if not session_info:
|
||||
# Cache miss - fetch from demo service
|
||||
logger.debug("Demo middleware: Fetching from demo service", session_id=session_id)
|
||||
session_info = await self._get_session_info(session_id)
|
||||
|
||||
# Cache the result if successful (30s TTL to balance freshness vs performance)
|
||||
if session_info:
|
||||
await self._cache_session(session_id, session_info, ttl=30)
|
||||
|
||||
# Accept pending, ready, partial, failed (if data exists), and active (deprecated) statuses
|
||||
# Even "failed" sessions can be usable if some services succeeded
|
||||
valid_statuses = ["pending", "ready", "partial", "failed", "active"]
|
||||
current_status = session_info.get("status") if session_info else None
|
||||
|
||||
if session_info and current_status in valid_statuses:
|
||||
# NOTE: Path transformation for demo-user removed.
|
||||
# Frontend now receives the real demo_user_id from session creation
|
||||
# and uses it directly in API calls.
|
||||
|
||||
# Inject virtual tenant ID
|
||||
# Use scope state directly to avoid potential state property issues
|
||||
request.scope.setdefault("state", {})
|
||||
state = request.scope["state"]
|
||||
state["tenant_id"] = session_info["virtual_tenant_id"]
|
||||
state["is_demo_session"] = True
|
||||
state["demo_account_type"] = session_info["demo_account_type"]
|
||||
state["demo_session_status"] = current_status # Track status for monitoring
|
||||
|
||||
# Inject demo user context for auth middleware
|
||||
# Uses DEMO_USER_IDS constant defined at module level
|
||||
demo_user_id = DEMO_USER_IDS.get(
|
||||
session_info.get("demo_account_type", "professional"),
|
||||
DEMO_USER_IDS["professional"]
|
||||
)
|
||||
|
||||
# This allows the request to pass through AuthMiddleware
|
||||
# NEW: Extract subscription tier from demo account type
|
||||
subscription_tier = "enterprise" if session_info.get("demo_account_type") == "enterprise" else "professional"
|
||||
|
||||
state["user"] = {
|
||||
"user_id": demo_user_id, # Use actual demo user UUID
|
||||
"email": f"demo-{session_id}@demo.local",
|
||||
"tenant_id": session_info["virtual_tenant_id"],
|
||||
"role": "owner", # Demo users have owner role
|
||||
"is_demo": True,
|
||||
"demo_session_id": session_id,
|
||||
"demo_account_type": session_info.get("demo_account_type", "professional"),
|
||||
"demo_session_status": current_status,
|
||||
# NEW: Subscription context (no HTTP call needed!)
|
||||
"subscription_tier": subscription_tier,
|
||||
"subscription_status": "active",
|
||||
"subscription_from_jwt": True # Flag to skip HTTP calls
|
||||
}
|
||||
|
||||
# Update activity
|
||||
await self._update_session_activity(session_id)
|
||||
|
||||
# Check if path is explicitly blocked
|
||||
blocked_reason = self._check_blocked_path(request.url.path)
|
||||
if blocked_reason:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
**blocked_reason,
|
||||
"upgrade_url": "/pricing",
|
||||
"session_expires_at": session_info.get("expires_at")
|
||||
}
|
||||
)
|
||||
|
||||
# Check if operation is allowed
|
||||
if not self._is_operation_allowed(request.method, request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
"message": "Esta operación no está permitida en cuentas demo. "
|
||||
"Las sesiones demo se eliminan automáticamente después de 30 minutos. "
|
||||
"Suscríbete para obtener acceso completo.",
|
||||
"message_en": "This operation is not allowed in demo accounts. "
|
||||
"Demo sessions are automatically deleted after 30 minutes. "
|
||||
"Subscribe for full access.",
|
||||
"upgrade_url": "/pricing",
|
||||
"session_expires_at": session_info.get("expires_at")
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Session expired, invalid, or in failed/destroyed state
|
||||
logger.warning(f"Invalid demo session state", session_id=session_id, status=current_status)
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "session_expired",
|
||||
"message": "Tu sesión demo ha expirado. Crea una nueva sesión para continuar.",
|
||||
"message_en": "Your demo session has expired. Create a new session to continue.",
|
||||
"session_status": current_status
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Demo middleware error", error=str(e), session_id=session_id, path=request.url.path)
|
||||
# On error, return 401 instead of continuing
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "session_error",
|
||||
"message": "Error validando sesión demo. Por favor, inténtalo de nuevo.",
|
||||
"message_en": "Error validating demo session. Please try again."
|
||||
}
|
||||
)
|
||||
|
||||
# Check if this is a demo tenant (base template)
|
||||
elif tenant_id in DEMO_TENANT_IDS:
|
||||
# Direct access to demo tenant without session - block writes
|
||||
request.scope.setdefault("state", {})
|
||||
state = request.scope["state"]
|
||||
state["is_demo_session"] = True
|
||||
state["tenant_id"] = tenant_id
|
||||
|
||||
if request.method not in ["GET", "HEAD", "OPTIONS"]:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "demo_restriction",
|
||||
"message": "Acceso directo al tenant demo no permitido. Crea una sesión demo.",
|
||||
"message_en": "Direct access to demo tenant not allowed. Create a demo session."
|
||||
}
|
||||
)
|
||||
|
||||
# Proceed with request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add demo session header to response if demo session
|
||||
if hasattr(request.state, "is_demo_session") and request.state.is_demo_session:
|
||||
response.headers["X-Demo-Session"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[dict]:
|
||||
"""Get session information from demo service using JWT service token"""
|
||||
try:
|
||||
# Create JWT service token for gateway-to-demo-session communication
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from app.core.config import settings
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
service_token = jwt_handler.create_service_token(service_name="gateway")
|
||||
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.demo_session_url}/api/v1/demo/sessions/{session_id}",
|
||||
headers={"Authorization": f"Bearer {service_token}"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning("Demo session fetch failed",
|
||||
session_id=session_id,
|
||||
status_code=response.status_code,
|
||||
response_text=response.text[:200] if hasattr(response, 'text') else '')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get session info", session_id=session_id, error=str(e))
|
||||
return None
|
||||
|
||||
async def _update_session_activity(self, session_id: str):
|
||||
"""Update session activity timestamp"""
|
||||
# Note: Activity tracking is handled by the demo service internally
|
||||
# No explicit endpoint needed - activity is updated on session access
|
||||
pass
|
||||
|
||||
def _check_blocked_path(self, path: str) -> Optional[dict]:
|
||||
"""Check if path is explicitly blocked for demo accounts"""
|
||||
for blocked_path in DEMO_BLOCKED_PATHS:
|
||||
if blocked_path in path:
|
||||
# Determine which category of blocked path
|
||||
if "forecast" in blocked_path:
|
||||
return DEMO_BLOCKED_PATH_MESSAGE["forecasts"]
|
||||
# Can add more categories here in the future
|
||||
return {
|
||||
"message": "Esta funcionalidad no está disponible para cuentas demo.",
|
||||
"message_en": "This functionality is not available for demo accounts."
|
||||
}
|
||||
return None
|
||||
|
||||
def _is_operation_allowed(self, method: str, path: str) -> bool:
|
||||
"""Check if method + path combination is allowed for demo"""
|
||||
|
||||
allowed_paths = DEMO_ALLOWED_OPERATIONS.get(method, [])
|
||||
|
||||
# Check for wildcard
|
||||
if "*" in allowed_paths:
|
||||
return True
|
||||
|
||||
# Check for exact match or pattern match
|
||||
for allowed_path in allowed_paths:
|
||||
if allowed_path.endswith("*"):
|
||||
# Pattern match: /api/orders/* matches /api/orders/123
|
||||
if path.startswith(allowed_path[:-1]):
|
||||
return True
|
||||
elif path == allowed_path:
|
||||
# Exact match
|
||||
return True
|
||||
|
||||
return False
|
||||
57
gateway/app/middleware/logging.py
Normal file
57
gateway/app/middleware/logging.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Logging middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Logging middleware class"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with logging"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Generate request ID
|
||||
request_id = str(uuid.uuid4())
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"Request: {request.method} {request.url.path}",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"query_params": str(request.query_params),
|
||||
"client_host": request.client.host if request.client else "unknown",
|
||||
"user_agent": request.headers.get("user-agent", ""),
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
f"Response: {response.status_code} in {duration:.3f}s",
|
||||
extra={
|
||||
"status_code": response.status_code,
|
||||
"duration": duration,
|
||||
"method": request.method,
|
||||
"url": request.url.path,
|
||||
"request_id": request_id
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
93
gateway/app/middleware/rate_limit.py
Normal file
93
gateway/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Rate limiting middleware for gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware class"""
|
||||
|
||||
def __init__(self, app, calls_per_minute: int = 60):
|
||||
super().__init__(app)
|
||||
self.calls_per_minute = calls_per_minute
|
||||
self.requests: Dict[str, list] = {}
|
||||
self._cleanup_task = None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with rate limiting"""
|
||||
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path in ["/health", "/metrics"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
|
||||
# Check rate limit
|
||||
if self._is_rate_limited(client_id):
|
||||
logger.warning(f"Rate limit exceeded for client: {client_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Record request
|
||||
self._record_request(client_id)
|
||||
|
||||
# Process request
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""Get client identifier"""
|
||||
# Try to get user ID from state (if authenticated)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
return f"ip:{request.client.host if request.client else 'unknown'}"
|
||||
|
||||
def _is_rate_limited(self, client_id: str) -> bool:
|
||||
"""Check if client is rate limited"""
|
||||
now = time.time()
|
||||
minute_ago = now - 60
|
||||
|
||||
# Get recent requests for this client
|
||||
if client_id not in self.requests:
|
||||
return False
|
||||
|
||||
# Filter requests from last minute
|
||||
recent_requests = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
# Update the list
|
||||
self.requests[client_id] = recent_requests
|
||||
|
||||
# Check if limit exceeded
|
||||
return len(recent_requests) >= self.calls_per_minute
|
||||
|
||||
def _record_request(self, client_id: str):
|
||||
"""Record a request for rate limiting"""
|
||||
now = time.time()
|
||||
|
||||
if client_id not in self.requests:
|
||||
self.requests[client_id] = []
|
||||
|
||||
self.requests[client_id].append(now)
|
||||
|
||||
# Keep only last minute of requests
|
||||
minute_ago = now - 60
|
||||
self.requests[client_id] = [
|
||||
req_time for req_time in self.requests[client_id]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
269
gateway/app/middleware/rate_limiting.py
Normal file
269
gateway/app/middleware/rate_limiting.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
API Rate Limiting Middleware for Gateway
|
||||
Enforces subscription-based API call quotas per hour
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import shared.redis_utils
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class APIRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce API rate limits based on subscription tier.
|
||||
|
||||
Quota limits per hour:
|
||||
- Starter: 100 calls/hour
|
||||
- Professional: 1,000 calls/hour
|
||||
- Enterprise: 10,000 calls/hour
|
||||
|
||||
Uses Redis to track API calls with hourly buckets.
|
||||
"""
|
||||
|
||||
def __init__(self, app, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Check API rate limit before processing request.
|
||||
"""
|
||||
# Skip rate limiting for certain paths
|
||||
if self._should_skip_rate_limit(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract tenant_id from request
|
||||
tenant_id = self._extract_tenant_id(request)
|
||||
|
||||
if not tenant_id:
|
||||
# No tenant ID - skip rate limiting for auth/public endpoints
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Get subscription tier from headers (added by AuthMiddleware)
|
||||
subscription_tier = request.headers.get("x-subscription-tier")
|
||||
|
||||
if not subscription_tier:
|
||||
# Fallback: get from request state if headers not available
|
||||
subscription_tier = getattr(request.state, "subscription_tier", None)
|
||||
|
||||
if not subscription_tier:
|
||||
# Final fallback: get from tenant service (should rarely happen)
|
||||
subscription_tier = await self._get_subscription_tier(tenant_id, request)
|
||||
logger.warning(f"Subscription tier not found in headers or state, fetched from tenant service: {subscription_tier}")
|
||||
|
||||
# Get quota limit for tier
|
||||
quota_limit = self._get_quota_limit(subscription_tier)
|
||||
|
||||
# Check and increment quota
|
||||
allowed, current_count = await self._check_and_increment_quota(
|
||||
tenant_id,
|
||||
quota_limit
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"API rate limit exceeded",
|
||||
tenant_id=tenant_id,
|
||||
subscription_tier=subscription_tier,
|
||||
current_count=current_count,
|
||||
quota_limit=quota_limit,
|
||||
path=request.url.path
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail={
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": f"API rate limit exceeded. Maximum {quota_limit} calls per hour allowed for {subscription_tier} plan.",
|
||||
"current_count": current_count,
|
||||
"quota_limit": quota_limit,
|
||||
"reset_time": self._get_reset_time(),
|
||||
"upgrade_required": subscription_tier in ['starter', 'professional']
|
||||
}
|
||||
)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit"] = str(quota_limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(max(0, quota_limit - current_count))
|
||||
response.headers["X-RateLimit-Reset"] = self._get_reset_time()
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Rate limiting check failed, allowing request",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
path=request.url.path
|
||||
)
|
||||
# Fail open - allow request if rate limiting fails
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_rate_limit(self, path: str) -> bool:
|
||||
"""
|
||||
Determine if path should skip rate limiting.
|
||||
"""
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/",
|
||||
"/api/v1/plans", # Public pricing info
|
||||
]
|
||||
|
||||
for skip_path in skip_paths:
|
||||
if path.startswith(skip_path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_tenant_id(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract tenant ID from request headers or path.
|
||||
"""
|
||||
# Try header first
|
||||
tenant_id = request.headers.get("x-tenant-id")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Try to extract from path /api/v1/tenants/{tenant_id}/...
|
||||
path_parts = request.url.path.split("/")
|
||||
if "tenants" in path_parts:
|
||||
try:
|
||||
tenant_index = path_parts.index("tenants")
|
||||
if len(path_parts) > tenant_index + 1:
|
||||
return path_parts[tenant_index + 1]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def _get_subscription_tier(self, tenant_id: str, request: Request) -> str:
|
||||
"""
|
||||
Get subscription tier from tenant service (with caching).
|
||||
"""
|
||||
try:
|
||||
# Try to get from request state (if subscription middleware already ran)
|
||||
if hasattr(request.state, "subscription_tier"):
|
||||
return request.state.subscription_tier
|
||||
|
||||
# Call tenant service to get tier
|
||||
import httpx
|
||||
from gateway.app.core.config import settings
|
||||
|
||||
async with httpx.AsyncClient(timeout=2.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers={
|
||||
"x-service": "gateway"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("tier", "starter")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to get subscription tier, defaulting to starter",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
return "starter"
|
||||
|
||||
def _get_quota_limit(self, subscription_tier: str) -> int:
|
||||
"""
|
||||
Get API calls per hour quota for subscription tier.
|
||||
"""
|
||||
quota_map = {
|
||||
"starter": 100,
|
||||
"professional": 1000,
|
||||
"enterprise": 10000,
|
||||
"demo": 1000, # Same as professional
|
||||
}
|
||||
|
||||
return quota_map.get(subscription_tier.lower(), 100)
|
||||
|
||||
async def _check_and_increment_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
quota_limit: int
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Check current quota usage and increment counter.
|
||||
|
||||
Returns:
|
||||
(allowed: bool, current_count: int)
|
||||
"""
|
||||
if not self.redis_client:
|
||||
# No Redis - fail open
|
||||
return True, 0
|
||||
|
||||
try:
|
||||
# Create hourly bucket key
|
||||
current_hour = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H")
|
||||
quota_key = f"quota:hourly:api_calls:{tenant_id}:{current_hour}"
|
||||
|
||||
# Get current count
|
||||
current_count = await self.redis_client.get(quota_key)
|
||||
current_count = int(current_count) if current_count else 0
|
||||
|
||||
# Check if within limit
|
||||
if current_count >= quota_limit:
|
||||
return False, current_count
|
||||
|
||||
# Increment counter
|
||||
new_count = await self.redis_client.incr(quota_key)
|
||||
|
||||
# Set expiry (1 hour + 5 minutes buffer)
|
||||
await self.redis_client.expire(quota_key, 3900)
|
||||
|
||||
return True, new_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Redis quota check failed",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open
|
||||
return True, 0
|
||||
|
||||
def _get_reset_time(self) -> str:
|
||||
"""
|
||||
Get the reset time for the current hour bucket (top of next hour).
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
next_hour = (now + timedelta(hours=1)).replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
return next_hour.isoformat()
|
||||
|
||||
|
||||
async def get_rate_limit_middleware(app):
|
||||
"""
|
||||
Factory function to create rate limiting middleware with Redis client.
|
||||
"""
|
||||
try:
|
||||
from gateway.app.core.config import settings
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
|
||||
logger.info("API rate limiting middleware initialized with Redis")
|
||||
return APIRateLimitMiddleware(app, redis_client=redis_client)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize Redis for rate limiting, middleware will fail open",
|
||||
error=str(e)
|
||||
)
|
||||
return APIRateLimitMiddleware(app, redis_client=None)
|
||||
149
gateway/app/middleware/read_only_mode.py
Normal file
149
gateway/app/middleware/read_only_mode.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Gateway middleware to enforce read-only mode for subscriptions with status:
|
||||
- pending_cancellation (until cancellation_effective_date)
|
||||
- inactive (after cancellation or no active subscription)
|
||||
|
||||
Allowed operations in read-only mode:
|
||||
- GET requests (all read operations)
|
||||
- POST /api/v1/users/me/delete/request (account deletion)
|
||||
- POST /api/v1/subscriptions/reactivate (subscription reactivation)
|
||||
- POST /api/v1/subscriptions/* (subscription management)
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Whitelist of POST/PUT/DELETE endpoints allowed in read-only mode
|
||||
READ_ONLY_WHITELIST_PATTERNS = [
|
||||
r'^/api/v1/users/me/delete/request$',
|
||||
r'^/api/v1/users/me/export.*$',
|
||||
r'^/api/v1/tenants/.*/subscription/.*', # All tenant subscription endpoints
|
||||
r'^/api/v1/registration/.*', # Registration flow endpoints
|
||||
r'^/api/v1/auth/.*', # Allow auth operations
|
||||
r'^/api/v1/tenants/register$', # Allow new tenant registration (no existing tenant context)
|
||||
r'^/api/v1/tenants/.*/orchestrator/run-daily-workflow$', # Allow workflow testing
|
||||
r'^/api/v1/tenants/.*/inventory/ml/insights/.*', # Allow ML insights (safety stock optimization)
|
||||
r'^/api/v1/tenants/.*/production/ml/insights/.*', # Allow ML insights (yield prediction)
|
||||
r'^/api/v1/tenants/.*/procurement/ml/insights/.*', # Allow ML insights (supplier analysis, price forecasting)
|
||||
r'^/api/v1/tenants/.*/forecasting/ml/insights/.*', # Allow ML insights (rules generation)
|
||||
r'^/api/v1/tenants/.*/forecasting/operations/.*', # Allow forecasting operations
|
||||
r'^/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
|
||||
]
|
||||
|
||||
|
||||
class ReadOnlyModeMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce read-only mode based on subscription status
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str = "http://tenant-service:8000"):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url
|
||||
self.cache = {}
|
||||
self.cache_ttl = 60
|
||||
|
||||
async def check_subscription_status(self, tenant_id: str, authorization: str) -> dict:
|
||||
"""
|
||||
Check subscription status from tenant service
|
||||
Returns subscription data including status and read_only flag
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.tenant_service_url}/api/v1/tenants/{tenant_id}/subscription/status",
|
||||
headers={"Authorization": authorization}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
return {"status": "inactive", "is_read_only": True}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to check subscription status: {response.status_code}",
|
||||
extra={"tenant_id": tenant_id}
|
||||
)
|
||||
return {"status": "unknown", "is_read_only": False}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking subscription status: {e}",
|
||||
extra={"tenant_id": tenant_id}
|
||||
)
|
||||
return {"status": "unknown", "is_read_only": False}
|
||||
|
||||
def is_whitelisted_endpoint(self, path: str) -> bool:
|
||||
"""
|
||||
Check if endpoint is whitelisted for read-only mode
|
||||
"""
|
||||
for pattern in READ_ONLY_WHITELIST_PATTERNS:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_write_operation(self, method: str) -> bool:
|
||||
"""
|
||||
Determine if HTTP method is a write operation
|
||||
"""
|
||||
return method.upper() in ['POST', 'PUT', 'DELETE', 'PATCH']
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Process each request through read-only mode check
|
||||
"""
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
authorization = request.headers.get("Authorization")
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
if not tenant_id or not authorization:
|
||||
return await call_next(request)
|
||||
|
||||
if method.upper() == 'GET':
|
||||
return await call_next(request)
|
||||
|
||||
if self.is_whitelisted_endpoint(path):
|
||||
return await call_next(request)
|
||||
|
||||
if self.is_write_operation(method):
|
||||
subscription_data = await self.check_subscription_status(tenant_id, authorization)
|
||||
|
||||
if subscription_data.get("is_read_only", False):
|
||||
status_detail = subscription_data.get("status", "inactive")
|
||||
effective_date = subscription_data.get("cancellation_effective_date")
|
||||
|
||||
error_message = {
|
||||
"detail": "Account is in read-only mode",
|
||||
"reason": f"Subscription status: {status_detail}",
|
||||
"message": "Your subscription has been cancelled. You can view data but cannot make changes.",
|
||||
"action_required": "Reactivate your subscription to regain full access",
|
||||
"reactivation_url": "/app/settings/subscription"
|
||||
}
|
||||
|
||||
if effective_date:
|
||||
error_message["read_only_until"] = effective_date
|
||||
error_message["message"] = f"Your subscription is pending cancellation. Read-only mode starts on {effective_date}."
|
||||
|
||||
logger.info(
|
||||
"read_only_mode_enforced",
|
||||
extra={
|
||||
"tenant_id": tenant_id,
|
||||
"path": path,
|
||||
"method": method,
|
||||
"subscription_status": status_detail
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content=error_message
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
83
gateway/app/middleware/request_id.py
Normal file
83
gateway/app/middleware/request_id.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Request ID Middleware for distributed tracing
|
||||
Generates and propagates unique request IDs across all services
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to generate and propagate request IDs for distributed tracing.
|
||||
|
||||
Request IDs are:
|
||||
- Generated if not provided by client
|
||||
- Logged with every request
|
||||
- Propagated to all downstream services
|
||||
- Returned in response headers
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with request ID tracking"""
|
||||
|
||||
# Extract or generate request ID
|
||||
request_id = request.headers.get("X-Request-ID")
|
||||
if not request_id:
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Store in request state for access by routes
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Bind request ID to structured logger context
|
||||
logger_ctx = logger.bind(request_id=request_id)
|
||||
|
||||
# Inject request ID header for downstream services using HeaderManager
|
||||
# Note: This runs early in middleware chain, so we use add_header_for_middleware
|
||||
header_manager.add_header_for_middleware(request, "x-request-id", request_id)
|
||||
|
||||
# Log request start
|
||||
logger_ctx.info(
|
||||
"Request started",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=request.client.host if request.client else None
|
||||
)
|
||||
|
||||
try:
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
# Log request completion
|
||||
logger_ctx.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Log request failure
|
||||
logger_ctx.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise
|
||||
462
gateway/app/middleware/subscription.py
Normal file
462
gateway/app/middleware/subscription.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Subscription Middleware - Enforces subscription limits and feature access
|
||||
Updated to support standardized URL structure with tier-based access control
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import structlog
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional, List
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.utils.subscription_error_responses import create_upgrade_required_response
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce subscription-based access control
|
||||
|
||||
Supports standardized URL structure:
|
||||
- Base routes (/api/v1/tenants/{tenant_id}/{service}/{resource}): ALL tiers
|
||||
- Dashboard routes (/api/v1/tenants/{tenant_id}/{service}/dashboard/*): ALL tiers
|
||||
- Analytics routes (/api/v1/tenants/{tenant_id}/{service}/analytics/*): PROFESSIONAL+
|
||||
- Operations routes (/api/v1/tenants/{tenant_id}/{service}/operations/*): ALL tiers (role-based)
|
||||
"""
|
||||
|
||||
def __init__(self, app, tenant_service_url: str, redis_client=None):
|
||||
super().__init__(app)
|
||||
self.tenant_service_url = tenant_service_url.rstrip('/')
|
||||
self.redis_client = redis_client # Optional Redis client for abuse detection
|
||||
|
||||
# Define route patterns that require subscription validation
|
||||
# Using new standardized URL structure
|
||||
self.protected_routes = {
|
||||
# ===== ANALYTICS ROUTES - PROFESSIONAL/ENTERPRISE ONLY =====
|
||||
# Any service analytics endpoint
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/analytics/.*': {
|
||||
'feature': 'analytics',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Analytics features (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# ===== TRAINING SERVICE - ALL TIERS =====
|
||||
r'^/api/v1/tenants/[^/]+/training/.*': {
|
||||
'feature': 'ml_training',
|
||||
'minimum_tier': 'basic',
|
||||
'allowed_tiers': ['basic', 'professional', 'enterprise'],
|
||||
'description': 'Machine learning model training (Available for all tiers)'
|
||||
},
|
||||
|
||||
# ===== ADVANCED FEATURES - PROFESSIONAL/ENTERPRISE =====
|
||||
# Advanced reporting and exports
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/export/advanced.*': {
|
||||
'feature': 'advanced_exports',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Advanced export formats (Professional/Enterprise only)'
|
||||
},
|
||||
|
||||
# Bulk operations
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/bulk/.*': {
|
||||
'feature': 'bulk_operations',
|
||||
'minimum_tier': 'professional',
|
||||
'allowed_tiers': ['professional', 'enterprise'],
|
||||
'description': 'Bulk operations (Professional/Enterprise only)'
|
||||
},
|
||||
}
|
||||
|
||||
# Routes that are explicitly allowed for all tiers (no check needed)
|
||||
self.public_tier_routes = [
|
||||
# Base CRUD operations - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/?$',
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/(?!analytics|export/advanced|bulk)[^/]+/[^/]+/?$',
|
||||
|
||||
# Dashboard routes - ALL TIERS
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/dashboard/.*',
|
||||
|
||||
# Operations routes - ALL TIERS (role-based control applies)
|
||||
r'^/api/v1/tenants/[^/]+/[^/]+/operations/.*',
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process the request and check subscription requirements"""
|
||||
|
||||
# Skip subscription check for certain routes
|
||||
if self._should_skip_subscription_check(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route is explicitly allowed for all tiers
|
||||
if self._is_public_tier_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Check if route requires subscription validation
|
||||
subscription_requirement = self._get_subscription_requirement(request.url.path)
|
||||
if not subscription_requirement:
|
||||
return await call_next(request)
|
||||
|
||||
# Get tenant ID from request
|
||||
tenant_id = self._extract_tenant_id(request)
|
||||
if not tenant_id:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error": "subscription_validation_failed",
|
||||
"message": "Tenant ID required for subscription validation",
|
||||
"code": "MISSING_TENANT_ID"
|
||||
}
|
||||
)
|
||||
|
||||
# Validate subscription with new tier-based system
|
||||
validation_result = await self._validate_subscription_tier(
|
||||
request,
|
||||
tenant_id,
|
||||
subscription_requirement.get('feature'),
|
||||
subscription_requirement.get('minimum_tier'),
|
||||
subscription_requirement.get('allowed_tiers', [])
|
||||
)
|
||||
|
||||
if not validation_result['allowed']:
|
||||
# Use enhanced error response with conversion optimization
|
||||
feature = subscription_requirement.get('feature')
|
||||
current_tier = validation_result.get('current_tier', 'unknown')
|
||||
required_tier = subscription_requirement.get('minimum_tier')
|
||||
allowed_tiers = subscription_requirement.get('allowed_tiers', [])
|
||||
|
||||
# Create conversion-optimized error response
|
||||
enhanced_response = create_upgrade_required_response(
|
||||
feature=feature,
|
||||
current_tier=current_tier,
|
||||
required_tier=required_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
custom_message=validation_result.get('message')
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=enhanced_response.status_code,
|
||||
content=enhanced_response.dict()
|
||||
)
|
||||
|
||||
# Subscription validation passed, continue with request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _is_public_tier_route(self, path: str) -> bool:
|
||||
"""
|
||||
Check if route is explicitly allowed for all subscription tiers
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if route is allowed for all tiers
|
||||
"""
|
||||
for pattern in self.public_tier_routes:
|
||||
if re.match(pattern, path):
|
||||
logger.debug("Route allowed for all tiers", path=path, pattern=pattern)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_skip_subscription_check(self, request: Request) -> bool:
|
||||
"""Check if subscription validation should be skipped"""
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
# Skip for health checks, auth, and public routes
|
||||
skip_patterns = [
|
||||
r'/health.*',
|
||||
r'/metrics.*',
|
||||
r'/api/v1/auth/.*',
|
||||
r'/api/v1/tenants/[^/]+/subscription/.*', # All tenant subscription endpoints
|
||||
r'/api/v1/registration/.*', # Registration flow endpoints
|
||||
r'/api/v1/tenants/[^/]+/members.*', # Basic tenant info
|
||||
r'/api/v1/webhooks/.*', # Webhook endpoints - no tenant context
|
||||
r'/docs.*',
|
||||
r'/openapi\.json',
|
||||
# Training monitoring endpoints (WebSocket and status checks)
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/live.*', # WebSocket endpoint
|
||||
r'/api/v1/tenants/[^/]+/training/jobs/.*/status.*', # Status polling endpoint
|
||||
]
|
||||
|
||||
# Skip OPTIONS requests (CORS preflight)
|
||||
if method == "OPTIONS":
|
||||
return True
|
||||
|
||||
for pattern in skip_patterns:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_subscription_requirement(self, path: str) -> Optional[Dict[str, str]]:
|
||||
"""Get subscription requirement for a given path"""
|
||||
for pattern, requirement in self.protected_routes.items():
|
||||
if re.match(pattern, path):
|
||||
return requirement
|
||||
return None
|
||||
|
||||
def _extract_tenant_id(self, request: Request) -> Optional[str]:
|
||||
"""Extract tenant ID from request"""
|
||||
# Try to get from URL path first
|
||||
path_match = re.search(r'/api/v1/tenants/([^/]+)/', request.url.path)
|
||||
if path_match:
|
||||
return path_match.group(1)
|
||||
|
||||
# Try to get from headers
|
||||
tenant_id = request.headers.get('x-tenant-id')
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Try to get from user state (set by auth middleware)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return request.state.user.get('tenant_id')
|
||||
|
||||
return None
|
||||
|
||||
async def _validate_subscription_tier(
|
||||
self,
|
||||
request: Request,
|
||||
tenant_id: str,
|
||||
feature: Optional[str],
|
||||
minimum_tier: str,
|
||||
allowed_tiers: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate subscription tier access using cached subscription lookup
|
||||
|
||||
Args:
|
||||
request: FastAPI request
|
||||
tenant_id: Tenant ID
|
||||
feature: Feature name (optional, for additional checks)
|
||||
minimum_tier: Minimum required subscription tier
|
||||
allowed_tiers: List of allowed subscription tiers
|
||||
|
||||
Returns:
|
||||
Dict with 'allowed' boolean and additional metadata
|
||||
"""
|
||||
try:
|
||||
# Check if JWT already has subscription
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id', 'unknown')
|
||||
|
||||
if user_context.get("subscription_from_jwt"):
|
||||
# Use JWT data directly - NO HTTP CALL!
|
||||
current_tier = user_context.get("subscription_tier", "starter")
|
||||
|
||||
logger.debug("Using subscription tier from JWT (no HTTP call)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers)
|
||||
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (JWT subscription)',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Use unified HeaderManager for consistent header handling
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Extract user_id for logging (fallback path)
|
||||
user_id = header_manager.get_header_value(request, 'x-user-id', 'unknown')
|
||||
|
||||
# Call tenant service fast tier endpoint with caching (fallback for old tokens)
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=1.0, # Connection timeout - very short for cached endpoint
|
||||
read=5.0, # Read timeout - short for cached lookup
|
||||
write=1.0, # Write timeout
|
||||
pool=1.0 # Pool timeout
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
# Use fast cached tier endpoint (new URL pattern)
|
||||
tier_response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/subscription/tier",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if tier_response.status_code != 200:
|
||||
logger.warning(
|
||||
"Failed to get subscription tier from cache",
|
||||
tenant_id=tenant_id,
|
||||
status_code=tier_response.status_code,
|
||||
response_text=tier_response.text
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_tier': 'unknown'
|
||||
}
|
||||
|
||||
tier_data = tier_response.json()
|
||||
current_tier = tier_data.get('tier', 'starter').lower()
|
||||
|
||||
logger.debug("Subscription tier check (cached)",
|
||||
tenant_id=tenant_id,
|
||||
current_tier=current_tier,
|
||||
minimum_tier=minimum_tier,
|
||||
allowed_tiers=allowed_tiers,
|
||||
cached=tier_data.get('cached', False))
|
||||
|
||||
# Check if current tier is in allowed tiers
|
||||
if current_tier not in [tier.lower() for tier in allowed_tiers]:
|
||||
tier_names = ', '.join(allowed_tiers)
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
False,
|
||||
"jwt"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': False,
|
||||
'message': f'This feature requires a {tier_names} subscription plan',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
# Tier check passed
|
||||
await self._log_subscription_access(
|
||||
tenant_id,
|
||||
user_id,
|
||||
feature,
|
||||
current_tier,
|
||||
True,
|
||||
"database"
|
||||
)
|
||||
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted',
|
||||
'current_tier': current_tier
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Timeout validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation timeout)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except httpx.RequestError as e:
|
||||
logger.error(
|
||||
"Request error validating subscription",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation service unavailable)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Subscription validation error",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e)
|
||||
)
|
||||
# Fail open for availability (let service handle detailed check)
|
||||
return {
|
||||
'allowed': True,
|
||||
'message': 'Access granted (validation error)',
|
||||
'current_plan': 'unknown'
|
||||
}
|
||||
|
||||
async def _log_subscription_access(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
requested_feature: str,
|
||||
current_tier: str,
|
||||
access_granted: bool,
|
||||
source: str # "jwt" or "database"
|
||||
):
|
||||
"""
|
||||
Log all subscription-gated access attempts for audit and anomaly detection.
|
||||
"""
|
||||
logger.info(
|
||||
"Subscription access check",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=requested_feature,
|
||||
tier=current_tier,
|
||||
granted=access_granted,
|
||||
source=source,
|
||||
timestamp=datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
# For denied access, check for suspicious patterns
|
||||
if not access_granted:
|
||||
await self._check_for_abuse_patterns(tenant_id, user_id, requested_feature)
|
||||
|
||||
async def _check_for_abuse_patterns(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
feature: str
|
||||
):
|
||||
"""
|
||||
Detect potential abuse patterns like repeated premium feature access attempts.
|
||||
"""
|
||||
if not self.redis_client:
|
||||
return
|
||||
|
||||
# Track denied attempts in a sliding window
|
||||
key = f"subscription_denied:{tenant_id}:{user_id}:{feature}"
|
||||
|
||||
try:
|
||||
attempts = await self.redis_client.incr(key)
|
||||
if attempts == 1:
|
||||
await self.redis_client.expire(key, 3600) # 1 hour window
|
||||
|
||||
# Alert if too many denied attempts (potential bypass attempt)
|
||||
if attempts > 10:
|
||||
logger.warning(
|
||||
"SECURITY: Excessive premium feature access attempts detected",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
feature=feature,
|
||||
attempts=attempts,
|
||||
window="1 hour"
|
||||
)
|
||||
# Could trigger alert to security team here
|
||||
except Exception as e:
|
||||
logger.warning("Failed to track abuse patterns", error=str(e))
|
||||
|
||||
0
gateway/app/routes/__init__.py
Normal file
0
gateway/app/routes/__init__.py
Normal file
240
gateway/app/routes/auth.py
Normal file
240
gateway/app/routes/auth.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# ================================================================
|
||||
# gateway/app/routes/auth.py
|
||||
# ================================================================
|
||||
"""
|
||||
Authentication and User Management Routes for API Gateway
|
||||
Unified proxy to auth microservice
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize service discovery and metrics
|
||||
service_discovery = ServiceDiscovery()
|
||||
metrics = MetricsCollector("gateway")
|
||||
|
||||
# Register custom metrics for auth routes
|
||||
metrics.register_counter("gateway_auth_requests_total", "Total authentication requests through gateway")
|
||||
metrics.register_counter("gateway_auth_responses_total", "Total authentication responses from gateway")
|
||||
metrics.register_counter("gateway_auth_errors_total", "Total authentication errors in gateway")
|
||||
|
||||
# Auth service configuration
|
||||
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"
|
||||
|
||||
|
||||
class AuthProxy:
|
||||
"""Authentication service proxy with enhanced error handling"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(30.0),
|
||||
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
|
||||
)
|
||||
|
||||
async def forward_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
request: Request
|
||||
) -> Response:
|
||||
"""Forward request to auth service with proper error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Get auth service URL (with service discovery if available)
|
||||
auth_url = await self._get_auth_service_url()
|
||||
target_url = f"{auth_url}/{path}"
|
||||
|
||||
# Prepare headers (remove hop-by-hop headers)
|
||||
# IMPORTANT: Use request.headers directly to get headers added by middleware
|
||||
# Also check request.state for headers injected by middleware
|
||||
headers = self._prepare_headers(request.headers, request)
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
|
||||
# Forward request
|
||||
logger.info(f"Forwarding {method} /{path} to auth service")
|
||||
|
||||
response = await self.client.request(
|
||||
method=method,
|
||||
url=target_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=dict(request.query_params)
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
metrics.increment_counter("gateway_auth_requests_total")
|
||||
metrics.increment_counter(
|
||||
"gateway_auth_responses_total",
|
||||
labels={"status_code": str(response.status_code)}
|
||||
)
|
||||
|
||||
# Prepare response headers
|
||||
response_headers = self._prepare_response_headers(dict(response.headers))
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers,
|
||||
media_type=response.headers.get("content-type")
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"Timeout forwarding {method} /{path} to auth service")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Authentication service timeout"
|
||||
)
|
||||
|
||||
except httpx.ConnectError:
|
||||
logger.error(f"Connection error forwarding {method} /{path} to auth service")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding {method} /{path} to auth service: {e}")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
|
||||
async def _get_auth_service_url(self) -> str:
|
||||
"""Get auth service URL with service discovery"""
|
||||
try:
|
||||
# Try service discovery first
|
||||
service_url = await service_discovery.get_service_url("auth-service")
|
||||
if service_url:
|
||||
return service_url
|
||||
except Exception as e:
|
||||
logger.warning(f"Service discovery failed: {e}")
|
||||
|
||||
# Fall back to configured URL
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding using unified HeaderManager"""
|
||||
# Use unified HeaderManager to get all headers
|
||||
if request:
|
||||
all_headers = header_manager.get_all_headers_for_proxy(request)
|
||||
logger.debug(f"DEBUG: Added headers from HeaderManager: {list(all_headers.keys())}")
|
||||
else:
|
||||
# Fallback: convert headers to dict manually
|
||||
all_headers = {}
|
||||
if hasattr(headers, '_list'):
|
||||
for k, v in headers.__dict__.get('_list', []):
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
elif hasattr(headers, 'raw'):
|
||||
for k, v in headers.raw:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
else:
|
||||
# Headers is already a dict
|
||||
all_headers = dict(headers)
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"📤 Forwarding headers - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")
|
||||
|
||||
return all_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Prepare response headers"""
|
||||
# Remove server-specific headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in {'server', 'date'}
|
||||
}
|
||||
|
||||
# Add CORS headers if needed
|
||||
if settings.CORS_ORIGINS:
|
||||
filtered_headers['Access-Control-Allow-Origin'] = '*'
|
||||
filtered_headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
|
||||
filtered_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
|
||||
|
||||
return filtered_headers
|
||||
|
||||
|
||||
# Initialize proxy
|
||||
auth_proxy = AuthProxy()
|
||||
|
||||
# ================================================================
|
||||
# CATCH-ALL ROUTE for all auth and user endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
|
||||
async def proxy_auth_requests(path: str, request: Request):
|
||||
"""Catch-all proxy for all auth and user requests"""
|
||||
return await auth_proxy.forward_request(request.method, f"api/v1/auth/{path}", request)
|
||||
|
||||
# ================================================================
|
||||
# HEALTH CHECK for auth service
|
||||
# ================================================================
|
||||
|
||||
@router.get("/health")
|
||||
async def auth_service_health():
|
||||
"""Check auth service health"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(f"{AUTH_SERVICE_URL}/health")
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"auth_service": "available",
|
||||
"response_time_ms": response.elapsed.total_seconds() * 1000
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"auth_service": "error",
|
||||
"status_code": response.status_code
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service health check failed: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"auth_service": "unavailable",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# ================================================================
|
||||
# CLEANUP
|
||||
# ================================================================
|
||||
|
||||
@router.on_event("shutdown")
|
||||
async def cleanup():
|
||||
"""Cleanup resources"""
|
||||
await auth_proxy.client.aclose()
|
||||
58
gateway/app/routes/demo.py
Normal file
58
gateway/app/routes/demo.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Demo Session Routes - Proxy to demo-session service
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.api_route("/demo/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
|
||||
async def proxy_demo_service(path: str, request: Request):
|
||||
"""
|
||||
Proxy all demo requests to the demo-session service
|
||||
These endpoints are public and don't require authentication
|
||||
"""
|
||||
# Build the target URL
|
||||
demo_service_url = settings.DEMO_SESSION_SERVICE_URL.rstrip('/')
|
||||
target_url = f"{demo_service_url}/api/v1/demo/{path}"
|
||||
|
||||
# Get request body
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
headers=headers,
|
||||
params=request.query_params,
|
||||
content=body
|
||||
)
|
||||
|
||||
# Return the response
|
||||
return JSONResponse(
|
||||
content=response.json() if response.content else {},
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Failed to proxy to demo-session service", error=str(e), url=target_url)
|
||||
raise HTTPException(status_code=503, detail="Demo service unavailable")
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error proxying to demo-session service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
71
gateway/app/routes/geocoding.py
Normal file
71
gateway/app/routes/geocoding.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# gateway/app/routes/geocoding.py
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
||||
async def proxy_geocoding(request: Request, path: str):
|
||||
"""
|
||||
Proxies all geocoding requests to the External Service geocoding endpoints.
|
||||
|
||||
Forwards requests from /api/v1/geocoding/* to external-service:8000/api/v1/geocoding/*
|
||||
"""
|
||||
try:
|
||||
# Construct the external service URL
|
||||
external_url = f"{settings.EXTERNAL_SERVICE_URL}/api/v1/geocoding/{path}"
|
||||
|
||||
# Get request body for POST/PUT/PATCH
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Make the proxied request
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=external_url,
|
||||
params=request.query_params,
|
||||
headers=headers,
|
||||
content=body
|
||||
)
|
||||
|
||||
# Return the response from external service
|
||||
return JSONResponse(
|
||||
content=response.json() if response.text else None,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error("External service geocoding request failed", error=str(exc), path=path)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Geocoding service unavailable: {exc}"
|
||||
)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
f"External service geocoding responded with error {exc.response.status_code}",
|
||||
detail=exc.response.text,
|
||||
path=path
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=exc.response.status_code,
|
||||
detail=f"Geocoding service error: {exc.response.text}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in geocoding proxy", error=str(exc), path=path)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error in geocoding proxy"
|
||||
)
|
||||
61
gateway/app/routes/nominatim.py
Normal file
61
gateway/app/routes/nominatim.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# gateway/app/routes/nominatim.py
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/search")
|
||||
async def proxy_nominatim_search(request: Request):
|
||||
"""
|
||||
Proxies requests to the Nominatim geocoding search API.
|
||||
"""
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Construct the internal Nominatim URL
|
||||
# All query parameters from the client request are forwarded
|
||||
nominatim_url = f"{settings.NOMINATIM_SERVICE_URL}/nominatim/search"
|
||||
|
||||
# httpx client for making async HTTP requests
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
nominatim_url,
|
||||
params=request.query_params # Forward all query parameters from frontend
|
||||
)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
# Return the JSON response from Nominatim directly
|
||||
return JSONResponse(content=response.json())
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error("Nominatim service request failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Nominatim service unavailable: {exc}"
|
||||
)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(f"Nominatim service responded with error {exc.response.status_code}",
|
||||
detail=exc.response.text)
|
||||
raise HTTPException(
|
||||
status_code=exc.response.status_code,
|
||||
detail=f"Nominatim service error: {exc.response.text}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in Nominatim proxy", error=str(exc))
|
||||
raise HTTPException(status_code=500, detail="Internal server error in Nominatim proxy")
|
||||
88
gateway/app/routes/poi_context.py
Normal file
88
gateway/app/routes/poi_context.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
POI Context Proxy Router
|
||||
Forwards all POI context requests to the External Service
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import structlog
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
|
||||
async def proxy_poi_context(request: Request, path: str):
|
||||
"""
|
||||
Proxies all POI context requests to the External Service.
|
||||
|
||||
Forwards requests from /api/v1/poi-context/* to external-service:8000/api/v1/poi-context/*
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
path: The path after /api/v1/poi-context/
|
||||
|
||||
Returns:
|
||||
JSONResponse with the response from the external service
|
||||
|
||||
Raises:
|
||||
HTTPException: If the external service is unavailable or returns an error
|
||||
"""
|
||||
try:
|
||||
# Construct the external service URL
|
||||
external_url = f"{settings.EXTERNAL_SERVICE_URL}/poi-context/{path}"
|
||||
|
||||
logger.debug("Proxying POI context request",
|
||||
method=request.method,
|
||||
path=path,
|
||||
external_url=external_url)
|
||||
|
||||
# Get request body for POST/PUT/PATCH requests
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Make the request to the external service
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=external_url,
|
||||
params=request.query_params,
|
||||
headers=headers,
|
||||
content=body
|
||||
)
|
||||
|
||||
logger.debug("POI context proxy response",
|
||||
status_code=response.status_code,
|
||||
path=path)
|
||||
|
||||
# Return the response from the external service
|
||||
return JSONResponse(
|
||||
content=response.json() if response.text else None,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
except httpx.RequestError as exc:
|
||||
logger.error("External service POI request failed",
|
||||
error=str(exc),
|
||||
path=path,
|
||||
external_url=external_url)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"POI service unavailable: {exc}"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in POI proxy",
|
||||
error=str(exc),
|
||||
path=path)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error in POI proxy"
|
||||
)
|
||||
89
gateway/app/routes/pos.py
Normal file
89
gateway/app/routes/pos.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
POS routes for API Gateway - Global POS endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# GLOBAL POS ENDPOINTS (No tenant context required)
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/supported-systems", methods=["GET", "OPTIONS"])
|
||||
async def proxy_supported_systems(request: Request):
|
||||
"""Proxy supported POS systems request to POS service"""
|
||||
target_path = "/api/v1/pos/supported-systems"
|
||||
return await _proxy_to_pos_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_pos_service(request: Request, target_path: str):
|
||||
"""Proxy request to POS service"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.POS_SERVICE_URL}{target_path}"
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0,
|
||||
read=60.0,
|
||||
write=30.0,
|
||||
pool=30.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying to POS service {target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
116
gateway/app/routes/registration.py
Normal file
116
gateway/app/routes/registration.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Registration routes for API Gateway - Handles registration-specific endpoints
|
||||
These endpoints don't require a tenant ID and are used during the registration flow
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# REGISTRATION ENDPOINTS - Direct routing to tenant service
|
||||
# These endpoints are called during registration before a tenant exists
|
||||
# ================================================================
|
||||
|
||||
@router.post("/registration-payment-setup")
|
||||
async def proxy_registration_payment_setup(request: Request):
|
||||
"""Proxy registration payment setup request to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, "/api/v1/tenants/registration-payment-setup")
|
||||
|
||||
@router.post("/verify-and-complete-registration")
|
||||
async def proxy_verify_and_complete_registration(request: Request):
|
||||
"""Proxy verification and registration completion to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, "/api/v1/tenants/verify-and-complete-registration")
|
||||
|
||||
@router.post("/payment-customers/create")
|
||||
async def proxy_registration_customer_create(request: Request):
|
||||
"""Proxy registration customer creation to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, "/api/v1/payment-customers/create")
|
||||
|
||||
@router.get("/setup-intents/{setup_intent_id}/verify")
|
||||
async def proxy_registration_setup_intent_verify(request: Request, setup_intent_id: str):
|
||||
"""Proxy registration setup intent verification to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/setup-intents/{setup_intent_id}/verify")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_tenant_service(request: Request, target_path: str):
|
||||
"""Generic proxy function with enhanced error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID, Stripe-Signature",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.TENANT_SERVICE_URL}{target_path}"
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Forwarding registration request to {url}")
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0,
|
||||
read=60.0,
|
||||
write=30.0,
|
||||
pool=30.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying registration request to {settings.TENANT_SERVICE_URL}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
317
gateway/app/routes/subscription.py
Normal file
317
gateway/app/routes/subscription.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Subscription routes for API Gateway - Direct subscription endpoints
|
||||
|
||||
New URL Pattern Architecture:
|
||||
- Registration: /registration/payment-setup, /registration/complete, /registration/state/{state_id}
|
||||
- Tenant Subscription: /tenants/{tenant_id}/subscription/*
|
||||
- Setup Intents: /setup-intents/{setup_intent_id}/verify
|
||||
- Payment Customers: /payment-customers/create
|
||||
- Plans: /plans (public)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# PUBLIC ENDPOINTS (No Authentication)
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/plans", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plans(request: Request):
|
||||
"""Proxy plans request to tenant service"""
|
||||
target_path = "/plans"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/plans/{tier}", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plan_details(request: Request, tier: str = Path(...)):
|
||||
"""Proxy specific plan details request to tenant service"""
|
||||
target_path = f"/plans/{tier}"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/plans/{tier}/features", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plan_features(request: Request, tier: str = Path(...)):
|
||||
"""Proxy plan features request to tenant service"""
|
||||
target_path = f"/plans/{tier}/features"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/plans/{tier}/limits", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plan_limits(request: Request, tier: str = Path(...)):
|
||||
"""Proxy plan limits request to tenant service"""
|
||||
target_path = f"/plans/{tier}/limits"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/plans/compare", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plan_compare(request: Request):
|
||||
"""Proxy plan comparison request to tenant service"""
|
||||
target_path = "/plans/compare"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# REGISTRATION FLOW ENDPOINTS (No Tenant Context)
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/registration/payment-setup", methods=["POST", "OPTIONS"])
|
||||
async def proxy_registration_payment_setup(request: Request):
|
||||
"""Proxy registration payment setup request to tenant service"""
|
||||
target_path = "/api/v1/registration/payment-setup"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/registration/complete", methods=["POST", "OPTIONS"])
|
||||
async def proxy_registration_complete(request: Request):
|
||||
"""Proxy registration completion request to tenant service"""
|
||||
target_path = "/api/v1/registration/complete"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/registration/state/{state_id}", methods=["GET", "OPTIONS"])
|
||||
async def proxy_registration_state(request: Request, state_id: str = Path(...)):
|
||||
"""Proxy registration state request to tenant service"""
|
||||
target_path = f"/api/v1/registration/state/{state_id}"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT SUBSCRIPTION STATUS ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/status", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_status(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription status request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/status"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/details", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_details(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription details request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/details"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/tier", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_tier(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription tier request to tenant service (cached)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/tier"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription limits request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/usage", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_usage(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription usage request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/usage"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/features/{feature}", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_feature(request: Request, tenant_id: str = Path(...), feature: str = Path(...)):
|
||||
"""Proxy subscription feature check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/features/{feature}"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# SUBSCRIPTION MANAGEMENT ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/cancel", methods=["POST", "OPTIONS"])
|
||||
async def proxy_subscription_cancel(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription cancellation request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/cancel"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/reactivate", methods=["POST", "OPTIONS"])
|
||||
async def proxy_subscription_reactivate(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription reactivation request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/reactivate"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan}", methods=["GET", "OPTIONS"])
|
||||
async def proxy_validate_upgrade(request: Request, tenant_id: str = Path(...), new_plan: str = Path(...)):
|
||||
"""Proxy plan upgrade validation request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan}"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/upgrade", methods=["POST", "OPTIONS"])
|
||||
async def proxy_subscription_upgrade(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy subscription upgrade request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/upgrade"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# QUOTA & LIMIT CHECK ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits/locations", methods=["GET", "OPTIONS"])
|
||||
async def proxy_location_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy location limits check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/locations"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits/products", methods=["GET", "OPTIONS"])
|
||||
async def proxy_product_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy product limits check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/products"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits/users", methods=["GET", "OPTIONS"])
|
||||
async def proxy_user_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy user limits check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/users"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits/recipes", methods=["GET", "OPTIONS"])
|
||||
async def proxy_recipe_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy recipe limits check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/recipes"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/limits/suppliers", methods=["GET", "OPTIONS"])
|
||||
async def proxy_supplier_limits(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy supplier limits check request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/limits/suppliers"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# PAYMENT MANAGEMENT ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/payment-method", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_payment_method(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy payment method request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/payment-method"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/tenants/{tenant_id}/subscription/invoices", methods=["GET", "OPTIONS"])
|
||||
async def proxy_invoices(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy invoices request to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/subscription/invoices"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# SETUP INTENT VERIFICATION
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/setup-intents/{setup_intent_id}/verify", methods=["GET", "OPTIONS"])
|
||||
async def proxy_setup_intent_verify(request: Request, setup_intent_id: str = Path(...)):
|
||||
"""Proxy SetupIntent verification request to tenant service"""
|
||||
target_path = f"/api/v1/setup-intents/{setup_intent_id}/verify"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# PAYMENT CUSTOMER MANAGEMENT
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/payment-customers/create", methods=["POST", "OPTIONS"])
|
||||
async def proxy_payment_customer_create(request: Request):
|
||||
"""Proxy payment customer creation request to tenant service"""
|
||||
target_path = "/api/v1/payment-customers/create"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# USAGE FORECAST ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/usage-forecast", methods=["GET", "OPTIONS"])
|
||||
async def proxy_usage_forecast(request: Request):
|
||||
"""Proxy usage forecast request to tenant service"""
|
||||
target_path = "/api/v1/usage-forecast"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/usage-forecast/track-usage", methods=["POST", "OPTIONS"])
|
||||
async def proxy_track_usage(request: Request):
|
||||
"""Proxy track usage request to tenant service"""
|
||||
target_path = "/api/v1/usage-forecast/track-usage"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_tenant_service(request: Request, target_path: str):
|
||||
"""Proxy request to tenant service"""
|
||||
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str, service_url: str):
|
||||
"""Generic proxy function with enhanced error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Debug logging
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
service_context = getattr(request.state, 'service', None)
|
||||
if user_context:
|
||||
logger.info(f"Forwarding subscription request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
|
||||
elif service_context:
|
||||
logger.debug(f"Forwarding subscription request to {url} with service context: service_name={service_context.get('service_name')}, user_type=service")
|
||||
else:
|
||||
logger.warning(f"No user or service context available when forwarding subscription request to {url}")
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0,
|
||||
read=60.0,
|
||||
write=30.0,
|
||||
pool=30.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying subscription request to {service_url}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
303
gateway/app/routes/telemetry.py
Normal file
303
gateway/app/routes/telemetry.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Telemetry routes for API Gateway - Handles frontend telemetry data
|
||||
|
||||
This module provides endpoints for:
|
||||
- Receiving OpenTelemetry traces from frontend
|
||||
- Proxying traces to Signoz OTel collector
|
||||
- Providing a secure, authenticated endpoint for frontend telemetry
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
import httpx
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from shared.monitoring.metrics import MetricsCollector, create_metrics_collector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
|
||||
|
||||
# Get Signoz OTel collector endpoint from environment or use default
|
||||
SIGNOZ_OTEL_COLLECTOR = os.getenv(
|
||||
"SIGNOZ_OTEL_COLLECTOR_URL",
|
||||
"http://signoz-otel-collector.bakery-ia.svc.cluster.local:4318"
|
||||
)
|
||||
|
||||
@router.post("/v1/traces")
|
||||
async def receive_frontend_traces(request: Request):
|
||||
"""
|
||||
Receive OpenTelemetry traces from frontend and proxy to Signoz
|
||||
|
||||
This endpoint:
|
||||
- Accepts OTLP trace data from frontend
|
||||
- Validates the request
|
||||
- Proxies to Signoz OTel collector
|
||||
- Handles errors gracefully
|
||||
"""
|
||||
|
||||
# Handle OPTIONS requests for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Get the trace data from the request
|
||||
body = await request.body()
|
||||
|
||||
if not body:
|
||||
logger.warning("Received empty trace data from frontend")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "Empty trace data"}
|
||||
)
|
||||
|
||||
# Log the trace reception (without sensitive data)
|
||||
logger.info(
|
||||
"Received frontend traces, content_length=%s, content_type=%s, user_agent=%s",
|
||||
len(body),
|
||||
request.headers.get("content-type"),
|
||||
request.headers.get("user-agent")
|
||||
)
|
||||
|
||||
# Forward to Signoz OTel collector
|
||||
target_url = f"{SIGNOZ_OTEL_COLLECTOR}/v1/traces"
|
||||
|
||||
# Set up headers for the Signoz collector
|
||||
forward_headers = {
|
||||
"Content-Type": request.headers.get("content-type", "application/json"),
|
||||
"User-Agent": "bakery-gateway/1.0",
|
||||
"X-Forwarded-For": request.headers.get("x-forwarded-for", "frontend"),
|
||||
"X-Tenant-ID": request.headers.get("x-tenant-id", "unknown")
|
||||
}
|
||||
|
||||
# Add authentication if configured
|
||||
signoz_auth_token = os.getenv("SIGNOZ_AUTH_TOKEN")
|
||||
if signoz_auth_token:
|
||||
forward_headers["Authorization"] = f"Bearer {signoz_auth_token}"
|
||||
|
||||
# Send to Signoz collector
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=5.0,
|
||||
read=10.0,
|
||||
write=5.0,
|
||||
pool=5.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.post(
|
||||
url=target_url,
|
||||
content=body,
|
||||
headers=forward_headers
|
||||
)
|
||||
|
||||
# Log the response from Signoz
|
||||
logger.info(
|
||||
"Forwarded traces to Signoz, signoz_status=%s, signoz_response_time=%s",
|
||||
response.status_code,
|
||||
response.elapsed.total_seconds()
|
||||
)
|
||||
|
||||
# Return success response to frontend
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"message": "Traces received and forwarded to Signoz",
|
||||
"signoz_status": response.status_code,
|
||||
"trace_count": 1 # We don't know exact count without parsing
|
||||
}
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
"Signoz collector returned error, status_code=%s, error_message=%s",
|
||||
e.response.status_code,
|
||||
str(e)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=502,
|
||||
content={
|
||||
"error": "Signoz collector error",
|
||||
"details": str(e),
|
||||
"signoz_status": e.response.status_code
|
||||
}
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(
|
||||
"Failed to connect to Signoz collector, error=%s, collector_url=%s",
|
||||
str(e),
|
||||
SIGNOZ_OTEL_COLLECTOR
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"error": "Signoz collector unavailable",
|
||||
"details": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error processing traces, error=%s, error_type=%s",
|
||||
str(e),
|
||||
type(e).__name__
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "Internal server error",
|
||||
"details": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/v1/metrics")
|
||||
async def receive_frontend_metrics(request: Request):
|
||||
"""
|
||||
Receive OpenTelemetry metrics from frontend and proxy to Signoz
|
||||
"""
|
||||
|
||||
# Handle OPTIONS requests for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
body = await request.body()
|
||||
|
||||
if not body:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "Empty metrics data"}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Received frontend metrics, content_length=%s, content_type=%s",
|
||||
len(body),
|
||||
request.headers.get("content-type")
|
||||
)
|
||||
|
||||
# Forward to Signoz OTel collector
|
||||
target_url = f"{SIGNOZ_OTEL_COLLECTOR}/v1/metrics"
|
||||
|
||||
forward_headers = {
|
||||
"Content-Type": request.headers.get("content-type", "application/json"),
|
||||
"User-Agent": "bakery-gateway/1.0",
|
||||
"X-Forwarded-For": request.headers.get("x-forwarded-for", "frontend"),
|
||||
"X-Tenant-ID": request.headers.get("x-tenant-id", "unknown")
|
||||
}
|
||||
|
||||
# Add authentication if configured
|
||||
signoz_auth_token = os.getenv("SIGNOZ_AUTH_TOKEN")
|
||||
if signoz_auth_token:
|
||||
forward_headers["Authorization"] = f"Bearer {signoz_auth_token}"
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=5.0,
|
||||
read=10.0,
|
||||
write=5.0,
|
||||
pool=5.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.post(
|
||||
url=target_url,
|
||||
content=body,
|
||||
headers=forward_headers
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Forwarded metrics to Signoz, signoz_status=%s",
|
||||
response.status_code
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"message": "Metrics received and forwarded to Signoz",
|
||||
"signoz_status": response.status_code
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error processing metrics, error=%s",
|
||||
str(e)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "Internal server error",
|
||||
"details": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/health")
|
||||
async def telemetry_health():
|
||||
"""
|
||||
Health check endpoint for telemetry service
|
||||
"""
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": "healthy",
|
||||
"service": "telemetry-gateway",
|
||||
"signoz_collector": SIGNOZ_OTEL_COLLECTOR
|
||||
}
|
||||
)
|
||||
|
||||
# Initialize metrics for this module
|
||||
try:
|
||||
metrics_collector = create_metrics_collector("gateway-telemetry")
|
||||
except Exception as e:
|
||||
logger.error("Failed to create metrics collector, error=%s", str(e))
|
||||
metrics_collector = None
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize telemetry metrics on startup"""
|
||||
try:
|
||||
if metrics_collector:
|
||||
# Register telemetry-specific metrics
|
||||
metrics_collector.register_counter(
|
||||
"gateway_telemetry_traces_received",
|
||||
"Number of trace batches received from frontend"
|
||||
)
|
||||
metrics_collector.register_counter(
|
||||
"gateway_telemetry_metrics_received",
|
||||
"Number of metric batches received from frontend"
|
||||
)
|
||||
metrics_collector.register_counter(
|
||||
"gateway_telemetry_errors",
|
||||
"Number of telemetry processing errors"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Telemetry gateway initialized, signoz_collector=%s",
|
||||
SIGNOZ_OTEL_COLLECTOR
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to initialize telemetry metrics, error=%s",
|
||||
str(e)
|
||||
)
|
||||
822
gateway/app/routes/tenant.py
Normal file
822
gateway/app/routes/tenant.py
Normal file
@@ -0,0 +1,822 @@
|
||||
# gateway/app/routes/tenant.py - COMPLETELY UPDATED
|
||||
"""
|
||||
Tenant routes for API Gateway - Handles all tenant-scoped endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# TENANT MANAGEMENT ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.post("/register")
|
||||
async def create_tenant(request: Request):
|
||||
"""Proxy tenant creation to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, "/api/v1/tenants/register")
|
||||
|
||||
|
||||
|
||||
@router.get("/{tenant_id}")
|
||||
async def get_tenant(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get specific tenant details"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
|
||||
|
||||
@router.put("/{tenant_id}")
|
||||
async def update_tenant(request: Request, tenant_id: str = Path(...)):
|
||||
"""Update tenant details"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
|
||||
|
||||
@router.get("/{tenant_id}/members")
|
||||
async def get_tenant_members(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get tenant members"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
|
||||
|
||||
@router.post("/{tenant_id}/members")
|
||||
async def add_tenant_member(request: Request, tenant_id: str = Path(...)):
|
||||
"""Add a team member to tenant"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
|
||||
|
||||
@router.post("/{tenant_id}/members/with-user")
|
||||
async def add_tenant_member_with_user(request: Request, tenant_id: str = Path(...)):
|
||||
"""Add a team member to tenant with user creation"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/with-user")
|
||||
|
||||
@router.put("/{tenant_id}/members/{member_user_id}/role")
|
||||
async def update_member_role(request: Request, tenant_id: str = Path(...), member_user_id: str = Path(...)):
|
||||
"""Update team member role"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/{member_user_id}/role")
|
||||
|
||||
@router.delete("/{tenant_id}/members/{member_user_id}")
|
||||
async def remove_tenant_member(request: Request, tenant_id: str = Path(...), member_user_id: str = Path(...)):
|
||||
"""Remove team member from tenant"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members/{member_user_id}")
|
||||
|
||||
@router.post("/{tenant_id}/transfer-ownership")
|
||||
async def transfer_tenant_ownership(request: Request, tenant_id: str = Path(...)):
|
||||
"""Transfer tenant ownership to another admin"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/transfer-ownership")
|
||||
|
||||
@router.get("/{tenant_id}/admins")
|
||||
async def get_tenant_admins(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get all admins for a tenant"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/admins")
|
||||
|
||||
@router.get("/{tenant_id}/hierarchy")
|
||||
async def get_tenant_hierarchy(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get tenant hierarchy information"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/hierarchy")
|
||||
|
||||
@router.api_route("/{tenant_id}/children", methods=["GET", "OPTIONS"])
|
||||
async def get_tenant_children(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get tenant children"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/children")
|
||||
|
||||
|
||||
@router.api_route("/{tenant_id}/bulk-children", methods=["POST", "OPTIONS"])
|
||||
async def proxy_bulk_children(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy bulk children creation requests to tenant service"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/bulk-children")
|
||||
|
||||
@router.api_route("/{tenant_id}/children/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_children(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant children requests to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/children/{path}".rstrip("/")
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/access/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_access(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant access requests to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/access/{path}".rstrip("/")
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.get("/{tenant_id}/my-access")
|
||||
async def get_tenant_my_access(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get current user's access level for a tenant"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/my-access")
|
||||
|
||||
@router.get("/user/{user_id}")
|
||||
async def get_user_tenants(request: Request, user_id: str = Path(...)):
|
||||
"""Get all tenant memberships for a user (admin only)"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/users/{user_id}")
|
||||
|
||||
@router.get("/user/{user_id}/owned")
|
||||
async def get_user_owned_tenants(request: Request, user_id: str = Path(...)):
|
||||
"""Get all tenants owned by a user"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/owned")
|
||||
|
||||
@router.get("/user/{user_id}/tenants")
|
||||
async def get_user_all_tenants(request: Request, user_id: str = Path(...)):
|
||||
"""Get all tenants accessible by a user (both owned and member tenants)"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/tenants")
|
||||
|
||||
@router.get("/users/{user_id}/primary-tenant")
|
||||
async def get_user_primary_tenant(request: Request, user_id: str = Path(...)):
|
||||
"""Get the primary tenant for a user (used by auth service for subscription validation)"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/users/{user_id}/primary-tenant")
|
||||
|
||||
@router.delete("/user/{user_id}/memberships")
|
||||
async def delete_user_tenants(request: Request, user_id: str = Path(...)):
|
||||
"""Get all tenant memberships for a user (admin only)"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/user/{user_id}/memberships")
|
||||
|
||||
# ================================================================
|
||||
# TENANT SETTINGS ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.get("/{tenant_id}/settings")
|
||||
async def get_tenant_settings(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get all settings for a tenant"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings")
|
||||
|
||||
@router.put("/{tenant_id}/settings")
|
||||
async def update_tenant_settings(request: Request, tenant_id: str = Path(...)):
|
||||
"""Update tenant settings"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings")
|
||||
|
||||
@router.get("/{tenant_id}/settings/{category}")
|
||||
async def get_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
|
||||
"""Get settings for a specific category"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}")
|
||||
|
||||
@router.put("/{tenant_id}/settings/{category}")
|
||||
async def update_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
|
||||
"""Update settings for a specific category"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}")
|
||||
|
||||
@router.post("/{tenant_id}/settings/{category}/reset")
|
||||
async def reset_category_settings(request: Request, tenant_id: str = Path(...), category: str = Path(...)):
|
||||
"""Reset a category to default values"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/settings/{category}/reset")
|
||||
|
||||
# ================================================================
|
||||
# TENANT SUBSCRIPTION ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
# NOTE: All subscription endpoints have been moved to gateway/app/routes/subscription.py
|
||||
# as part of the architecture redesign for better separation of concerns.
|
||||
# This wildcard route has been removed to avoid conflicts with the new specific routes.
|
||||
|
||||
@router.api_route("/subscriptions/plans", methods=["GET", "OPTIONS"])
|
||||
async def proxy_available_plans(request: Request):
|
||||
"""Proxy available plans request to tenant service"""
|
||||
target_path = "/api/v1/plans/available"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# BATCH OPERATIONS ENDPOINTS
|
||||
# IMPORTANT: Route order matters! Keep specific routes before wildcards:
|
||||
# 1. Exact matches first (/batch/sales-summary)
|
||||
# 2. Wildcard paths second (/batch{path:path})
|
||||
# 3. Tenant-scoped wildcards last (/{tenant_id}/batch{path:path})
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/batch/sales-summary", methods=["POST"])
|
||||
async def proxy_batch_sales_summary(request: Request):
|
||||
"""Proxy batch sales summary request to sales service"""
|
||||
target_path = "/api/v1/batch/sales-summary"
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
|
||||
@router.api_route("/batch/production-summary", methods=["POST"])
|
||||
async def proxy_batch_production_summary(request: Request):
|
||||
"""Proxy batch production summary request to production service"""
|
||||
target_path = "/api/v1/batch/production-summary"
|
||||
return await _proxy_to_production_service(request, target_path)
|
||||
|
||||
|
||||
@router.api_route("/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_batch_operations(request: Request, path: str = ""):
|
||||
"""Proxy batch operations that span multiple tenants to appropriate services"""
|
||||
# For batch operations, route based on the path after /batch/
|
||||
if path.startswith("/sales-summary"):
|
||||
# Route batch sales summary to sales service
|
||||
# The sales service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
elif path.startswith("/production-summary"):
|
||||
# Route batch production summary to production service
|
||||
# The production service batch endpoints are at /api/v1/batch/... not /api/v1/production/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_production_service(request, target_path)
|
||||
else:
|
||||
# Default to sales service for other batch operations
|
||||
# The service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
|
||||
@router.api_route("/{tenant_id}/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_batch_operations(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant-scoped batch operations to appropriate services"""
|
||||
# For tenant-scoped batch operations, route based on the path after /batch/
|
||||
if path.startswith("/sales-summary"):
|
||||
# Route tenant batch sales summary to sales service
|
||||
# The sales service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
elif path.startswith("/production-summary"):
|
||||
# Route tenant batch production summary to production service
|
||||
# The production service batch endpoints are at /api/v1/batch/... not /api/v1/production/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_production_service(request, target_path)
|
||||
else:
|
||||
# Default to sales service for other tenant batch operations
|
||||
# The service batch endpoints are at /api/v1/batch/... not /api/v1/sales/batch/...
|
||||
target_path = f"/api/v1/batch{path}"
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED DATA SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/sales{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_all_tenant_sales_alternative(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy all tenant sales requests - handles both base and sub-paths"""
|
||||
base_path = f"/api/v1/tenants/{tenant_id}/sales"
|
||||
|
||||
# If path is empty or just "/", use base path
|
||||
if not path or path == "/" or path == "":
|
||||
target_path = base_path
|
||||
else:
|
||||
# Ensure path starts with "/"
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
target_path = base_path + path
|
||||
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
|
||||
@router.api_route("/{tenant_id}/enterprise/batch{path:path}", methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_enterprise_batch(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy enterprise batch requests (spanning multiple tenants within an enterprise) to appropriate services"""
|
||||
# Forward to orchestrator service for enterprise-level operations
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/enterprise/batch{path}".rstrip("/")
|
||||
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_weather(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant weather requests to external service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/weather/{path}".rstrip("/")
|
||||
return await _proxy_to_external_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/traffic/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_traffic(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant traffic requests to external service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/traffic/{path}".rstrip("/")
|
||||
return await _proxy_to_external_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/external/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_external(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant external service requests (v2.0 city-based optimized endpoints)"""
|
||||
# Route to external service with normal path structure
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/external/{path}".rstrip("/")
|
||||
return await _proxy_to_external_service(request, target_path)
|
||||
|
||||
# Service-specific analytics routes (must come BEFORE the general analytics route)
|
||||
@router.api_route("/{tenant_id}/procurement/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_procurement_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant procurement analytics requests to procurement service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/procurement/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_procurement_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/inventory/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_inventory_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant inventory analytics requests to inventory service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/inventory/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/production/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_production_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant production analytics requests to production service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/production/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_production_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/sales/analytics/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_sales_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant sales analytics requests to sales service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/sales/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/analytics/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant analytics requests to sales service (fallback for non-service-specific analytics)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_sales_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED AI INSIGHTS ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/insights{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_insights(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant AI insights requests to AI insights service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/insights{path}".rstrip("/")
|
||||
return await _proxy_to_ai_insights_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/onboarding/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_onboarding(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant onboarding requests to tenant service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/onboarding/{path}".rstrip("/")
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED TRAINING SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/training/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_training(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant training requests to training service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/training/{path}".rstrip("/")
|
||||
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/models/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_models(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant model requests to training service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/models/{path}".rstrip("/")
|
||||
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/statistics", methods=["GET", "OPTIONS"])
|
||||
async def proxy_tenant_statistics(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant statistics requests to training service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/statistics"
|
||||
return await _proxy_to_training_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED FORECASTING SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasting/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_forecasting(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant forecasting requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/forecasting/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasting/enterprise/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_forecasting_enterprise(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant forecasting enterprise requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/forecasting/enterprise/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_forecasts(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant forecast requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/forecasts/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_predictions(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant prediction requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/predictions/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED NOTIFICATION SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/notifications/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_notifications(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant notification requests to notification service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/notifications/{path}".rstrip("/")
|
||||
return await _proxy_to_notification_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED ALERT ANALYTICS ENDPOINTS (Must come BEFORE inventory alerts)
|
||||
# ================================================================
|
||||
|
||||
# Exact match for /alerts endpoint (without additional path)
|
||||
@router.api_route("/{tenant_id}/alerts", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_alerts_list(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant alerts list requests to alert processor service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/alerts"
|
||||
return await _proxy_to_alert_processor_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/alerts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_alert_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant alert analytics requests to alert processor service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/alerts/{path}".rstrip("/")
|
||||
return await _proxy_to_alert_processor_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED INVENTORY SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/alerts{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_alerts_inventory(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant alerts requests to inventory service (legacy/food-safety alerts)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/alerts{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/inventory/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_inventory(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant inventory requests to inventory service"""
|
||||
# The inventory service expects /api/v1/tenants/{tenant_id}/inventory/{path}
|
||||
# Keep the full path structure for inventory service
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/inventory/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# Specific route for ingredients without additional path
|
||||
@router.api_route("/{tenant_id}/ingredients", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_ingredients_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant ingredient requests to inventory service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/ingredients"
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/ingredients/count", methods=["GET"])
|
||||
async def proxy_tenant_ingredients_count(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant ingredient count requests to inventory service"""
|
||||
# Inventory service uses RouteBuilder('inventory').build_base_route("ingredients/count")
|
||||
# which generates /api/v1/tenants/{tenant_id}/inventory/ingredients/count
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/inventory/ingredients/count"
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/ingredients/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_ingredients_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant ingredient requests to inventory service (with additional path)"""
|
||||
# The inventory service ingredient endpoints are now tenant-scoped: /api/v1/tenants/{tenant_id}/ingredients/{path}
|
||||
# Keep the full tenant path structure
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/ingredients/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/stock/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_stock(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant stock requests to inventory service"""
|
||||
# The inventory service stock endpoints are now tenant-scoped: /api/v1/tenants/{tenant_id}/stock/{path}
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/stock/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/dashboard/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_dashboard(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant dashboard requests to orchestrator service"""
|
||||
# The orchestrator service dashboard endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/dashboard/{path}
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/dashboard/{path}".rstrip("/")
|
||||
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/transformations", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_transformations_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant transformations requests to inventory service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/transformations"
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/transformations/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_transformations_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant transformations requests to inventory service (with additional path)"""
|
||||
# The inventory service transformations endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/transformations/{path}
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/transformations/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/sustainability/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_sustainability(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant sustainability requests to inventory service"""
|
||||
# The inventory service sustainability endpoints are tenant-scoped: /api/v1/tenants/{tenant_id}/sustainability/{path}
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/sustainability/{path}".rstrip("/")
|
||||
return await _proxy_to_inventory_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED PRODUCTION SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/production/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_production(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant production requests to production service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/production/{path}".rstrip("/")
|
||||
return await _proxy_to_production_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED ORCHESTRATOR SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/enterprise/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_enterprise(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant enterprise dashboard requests to orchestrator service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/enterprise/{path}".rstrip("/")
|
||||
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/orchestrator/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_orchestrator(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant orchestrator requests to orchestrator service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/orchestrator/{path}".rstrip("/")
|
||||
return await _proxy_to_orchestrator_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED ORDERS SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/orders", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_orders_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant orders requests to orders service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/orders"
|
||||
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/orders/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_orders_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant orders requests to orders service (with additional path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/orders/{path}".rstrip("/")
|
||||
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/customers/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_customers(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant customers requests to orders service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/orders/customers/{path}".rstrip("/")
|
||||
return await _proxy_to_orders_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/procurement/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_procurement(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant procurement requests to procurement service"""
|
||||
# For all procurement routes, we need to maintain the /procurement/ part in the path
|
||||
# The procurement service now uses standardized paths with RouteBuilder
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/procurement/{path}".rstrip("/")
|
||||
return await _proxy_to_procurement_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED SUPPLIER SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/suppliers", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_suppliers_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant supplier requests to suppliers service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/suppliers"
|
||||
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/suppliers/count", methods=["GET"])
|
||||
async def proxy_tenant_suppliers_count(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant supplier count requests to suppliers service"""
|
||||
# Suppliers service uses RouteBuilder('suppliers').build_operations_route("count")
|
||||
# which generates /api/v1/tenants/{tenant_id}/suppliers/operations/count
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/suppliers/operations/count"
|
||||
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/suppliers/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_suppliers_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant supplier requests to suppliers service (with additional path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/suppliers/{path}".rstrip("/")
|
||||
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# NOTE: Purchase orders are now accessed via the main procurement route:
|
||||
# /api/v1/tenants/{tenant_id}/procurement/purchase-orders/*
|
||||
# Legacy route removed to enforce standardized structure
|
||||
|
||||
@router.api_route("/{tenant_id}/deliveries{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_deliveries(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant delivery requests to suppliers service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/deliveries{path}".rstrip("/")
|
||||
return await _proxy_to_suppliers_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED LOCATIONS ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/locations", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_locations_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant locations requests to tenant service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/locations"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/locations/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_locations_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant locations requests to tenant service (with additional path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/locations/{path}".rstrip("/")
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED DISTRIBUTION SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/distribution/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_distribution(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant distribution requests to distribution service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/distribution/{path}".rstrip("/")
|
||||
return await _proxy_to_distribution_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED RECIPES SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/recipes", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_recipes_base(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant recipes requests to recipes service (base path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/recipes"
|
||||
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/recipes/count", methods=["GET"])
|
||||
async def proxy_tenant_recipes_count(request: Request, tenant_id: str = Path(...)):
|
||||
"""Proxy tenant recipes count requests to recipes service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/recipes/count"
|
||||
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/recipes/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_recipes_with_path(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant recipes requests to recipes service (with additional path)"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/recipes/{path}".rstrip("/")
|
||||
return await _proxy_to_recipes_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED POS SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/pos/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_pos(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant POS requests to POS service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/pos/{path}".rstrip("/")
|
||||
return await _proxy_to_pos_service(request, target_path, tenant_id=tenant_id)
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_tenant_service(request: Request, target_path: str):
|
||||
"""Proxy request to tenant service"""
|
||||
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_sales_service(request: Request, target_path: str):
|
||||
"""Proxy request to sales service"""
|
||||
return await _proxy_request(request, target_path, settings.SALES_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_external_service(request: Request, target_path: str):
|
||||
"""Proxy request to external service"""
|
||||
return await _proxy_request(request, target_path, settings.EXTERNAL_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_training_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to training service"""
|
||||
return await _proxy_request(request, target_path, settings.TRAINING_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_forecasting_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to forecasting service"""
|
||||
return await _proxy_request(request, target_path, settings.FORECASTING_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_notification_service(request: Request, target_path: str):
|
||||
"""Proxy request to notification service"""
|
||||
return await _proxy_request(request, target_path, settings.NOTIFICATION_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_inventory_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to inventory service"""
|
||||
return await _proxy_request(request, target_path, settings.INVENTORY_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_production_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to production service"""
|
||||
return await _proxy_request(request, target_path, settings.PRODUCTION_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_orders_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to orders service"""
|
||||
return await _proxy_request(request, target_path, settings.ORDERS_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_suppliers_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to suppliers service"""
|
||||
return await _proxy_request(request, target_path, settings.SUPPLIERS_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_recipes_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to recipes service"""
|
||||
return await _proxy_request(request, target_path, settings.RECIPES_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_pos_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to POS service"""
|
||||
return await _proxy_request(request, target_path, settings.POS_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_procurement_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to procurement service"""
|
||||
return await _proxy_request(request, target_path, settings.PROCUREMENT_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_alert_processor_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to alert processor service"""
|
||||
return await _proxy_request(request, target_path, settings.ALERT_PROCESSOR_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_orchestrator_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to orchestrator service"""
|
||||
return await _proxy_request(request, target_path, settings.ORCHESTRATOR_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_ai_insights_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to AI insights service"""
|
||||
return await _proxy_request(request, target_path, settings.AI_INSIGHTS_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_to_distribution_service(request: Request, target_path: str, tenant_id: str = None):
|
||||
"""Proxy request to distribution service"""
|
||||
return await _proxy_request(request, target_path, settings.DISTRIBUTION_SERVICE_URL, tenant_id=tenant_id)
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str, service_url: str, tenant_id: str = None):
|
||||
"""Generic proxy function with enhanced error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Add tenant ID header if provided (override if needed)
|
||||
if tenant_id:
|
||||
headers["x-tenant-id"] = tenant_id
|
||||
|
||||
# Debug logging
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
if user_context:
|
||||
logger.info(f"Forwarding request to {url} with user context: user_id={user_context.get('user_id')}, email={user_context.get('email')}, tenant_id={tenant_id}, subscription_tier={user_context.get('subscription_tier', 'not_set')}")
|
||||
else:
|
||||
logger.warning(f"No user context available when forwarding request to {url}. request.state.user: {getattr(request.state, 'user', 'NOT_SET')}")
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
files = None
|
||||
data = None
|
||||
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
content_type = request.headers.get("content-type", "")
|
||||
logger.info(f"Processing {request.method} request with content-type: {content_type}")
|
||||
|
||||
# Handle multipart/form-data (file uploads)
|
||||
if "multipart/form-data" in content_type:
|
||||
logger.info("Detected multipart/form-data, parsing form...")
|
||||
# For multipart/form-data, we need to re-parse and forward as files
|
||||
form = await request.form()
|
||||
logger.info(f"Form parsed, found {len(form)} fields: {list(form.keys())}")
|
||||
|
||||
# Extract files and form fields separately
|
||||
files_dict = {}
|
||||
data_dict = {}
|
||||
|
||||
for key, value in form.items():
|
||||
if hasattr(value, 'file'): # It's a file
|
||||
# Read file content
|
||||
file_content = await value.read()
|
||||
files_dict[key] = (value.filename, file_content, value.content_type)
|
||||
logger.info(f"Found file field '{key}': filename={value.filename}, size={len(file_content)}, type={value.content_type}")
|
||||
else: # It's a regular form field
|
||||
data_dict[key] = value
|
||||
logger.info(f"Found form field '{key}': value={value}")
|
||||
|
||||
files = files_dict if files_dict else None
|
||||
data = data_dict if data_dict else None
|
||||
|
||||
logger.info(f"Forwarding multipart request with files={list(files.keys()) if files else None}, data={list(data.keys()) if data else None}")
|
||||
|
||||
# For multipart requests, we need to get fresh headers since httpx will set content-type
|
||||
# Get all headers again to ensure we have the complete set
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
# httpx will automatically set content-type for multipart, so we don't need to remove it
|
||||
else:
|
||||
# For other content types, use body as before
|
||||
body = await request.body()
|
||||
logger.info(f"Using raw body, size: {len(body)} bytes")
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0, # Connection timeout
|
||||
read=600.0, # Read timeout: 10 minutes (was 30s)
|
||||
write=30.0, # Write timeout
|
||||
pool=30.0 # Pool timeout
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
files=files,
|
||||
data=data,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying to {service_url}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
205
gateway/app/routes/user.py
Normal file
205
gateway/app/routes/user.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# ================================================================
|
||||
# gateway/app/routes/user.py
|
||||
# ================================================================
|
||||
"""
|
||||
Authentication routes for API Gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize service discovery and metrics
|
||||
service_discovery = ServiceDiscovery()
|
||||
metrics = MetricsCollector("gateway")
|
||||
|
||||
# Auth service configuration
|
||||
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or "http://auth-service:8000"
|
||||
|
||||
class UserProxy:
|
||||
"""Authentication service proxy with enhanced error handling"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(30.0),
|
||||
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
|
||||
)
|
||||
|
||||
async def forward_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
request: Request
|
||||
) -> Response:
|
||||
"""Forward request to auth service with proper error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Get auth service URL (with service discovery if available)
|
||||
auth_url = await self._get_auth_service_url()
|
||||
# FIX: Auth service uses /api/v1/auth/ prefix, not /api/v1/users/
|
||||
target_url = f"{auth_url}/api/v1/auth/{path}"
|
||||
|
||||
# Prepare headers (remove hop-by-hop headers)
|
||||
# IMPORTANT: Use request.headers directly to get headers added by middleware
|
||||
# Also check request.state for headers injected by middleware
|
||||
headers = self._prepare_headers(request.headers, request)
|
||||
|
||||
# Get request body
|
||||
body = await request.body()
|
||||
|
||||
# Forward request
|
||||
logger.info(f"Forwarding {method} {path} to auth service")
|
||||
|
||||
response = await self.client.request(
|
||||
method=method,
|
||||
url=target_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=dict(request.query_params)
|
||||
)
|
||||
|
||||
# Record metrics
|
||||
metrics.increment_counter("gateway_auth_requests_total")
|
||||
metrics.increment_counter(
|
||||
"gateway_auth_responses_total",
|
||||
labels={"status_code": str(response.status_code)}
|
||||
)
|
||||
|
||||
# Prepare response headers
|
||||
response_headers = self._prepare_response_headers(dict(response.headers))
|
||||
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers,
|
||||
media_type=response.headers.get("content-type")
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"Timeout forwarding {method} {path} to auth service")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Authentication service timeout"
|
||||
)
|
||||
|
||||
except httpx.ConnectError:
|
||||
logger.error(f"Connection error forwarding {method} {path} to auth service")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service unavailable"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding {method} {path} to auth service: {e}")
|
||||
metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
|
||||
async def _get_auth_service_url(self) -> str:
|
||||
"""Get auth service URL with service discovery"""
|
||||
try:
|
||||
# Try service discovery first
|
||||
service_url = await service_discovery.get_service_url("auth-service")
|
||||
if service_url:
|
||||
return service_url
|
||||
except Exception as e:
|
||||
logger.warning(f"Service discovery failed: {e}")
|
||||
|
||||
# Fall back to configured URL
|
||||
return AUTH_SERVICE_URL
|
||||
|
||||
def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
|
||||
"""Prepare headers for forwarding using unified HeaderManager"""
|
||||
# Use unified HeaderManager to get all headers
|
||||
if request:
|
||||
all_headers = header_manager.get_all_headers_for_proxy(request)
|
||||
else:
|
||||
# Fallback: convert headers to dict manually
|
||||
all_headers = {}
|
||||
if hasattr(headers, '_list'):
|
||||
for k, v in headers.__dict__.get('_list', []):
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
elif hasattr(headers, 'raw'):
|
||||
for k, v in headers.raw:
|
||||
key = k.decode() if isinstance(k, bytes) else k
|
||||
value = v.decode() if isinstance(v, bytes) else v
|
||||
all_headers[key] = value
|
||||
else:
|
||||
# Headers is already a dict
|
||||
all_headers = dict(headers)
|
||||
|
||||
return all_headers
|
||||
|
||||
def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Prepare response headers"""
|
||||
# Remove server-specific headers
|
||||
filtered_headers = {
|
||||
k: v for k, v in headers.items()
|
||||
if k.lower() not in {'server', 'date'}
|
||||
}
|
||||
|
||||
# Add CORS headers if needed
|
||||
if settings.CORS_ORIGINS:
|
||||
filtered_headers['Access-Control-Allow-Origin'] = '*'
|
||||
filtered_headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
|
||||
filtered_headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
|
||||
|
||||
return filtered_headers
|
||||
|
||||
# Initialize proxy
|
||||
user_proxy = UserProxy()
|
||||
|
||||
# ================================================================
|
||||
# USER MANAGEMENT ENDPOINTS - Proxied to auth service
|
||||
# ================================================================
|
||||
|
||||
|
||||
|
||||
@router.get("/delete/{user_id}/deletion-preview")
|
||||
async def preview_user_deletion(user_id: str, request: Request):
|
||||
"""Proxy user deletion preview to auth service"""
|
||||
return await user_proxy.forward_request("GET", f"delete/{user_id}/deletion-preview", request)
|
||||
|
||||
@router.delete("/delete/{user_id}")
|
||||
async def delete_user(user_id: str, request: Request):
|
||||
"""Proxy admin user deletion to auth service"""
|
||||
return await user_proxy.forward_request("DELETE", f"delete/{user_id}", request)
|
||||
|
||||
# ================================================================
|
||||
# CATCH-ALL ROUTE for any other user endpoints
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
|
||||
async def proxy_auth_requests(path: str, request: Request):
|
||||
"""Catch-all proxy for auth requests"""
|
||||
return await user_proxy.forward_request(request.method, path, request)
|
||||
116
gateway/app/routes/webhooks.py
Normal file
116
gateway/app/routes/webhooks.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Webhook routes for API Gateway - Handles webhook endpoints
|
||||
|
||||
Route Configuration Notes:
|
||||
- Stripe configures webhook URL as: https://domain.com/api/v1/webhooks/stripe
|
||||
- Gateway receives /api/v1/webhooks/* routes and proxies to tenant service at /webhooks/*
|
||||
- Gateway routes use /api/v1 prefix, but tenant service routes use /webhooks/* prefix
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.header_manager import header_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# WEBHOOK ENDPOINTS - Direct routing to tenant service
|
||||
# All routes use /api/v1 prefix for consistency
|
||||
# ================================================================
|
||||
|
||||
# Stripe webhook endpoint
|
||||
@router.post("/api/v1/webhooks/stripe")
|
||||
async def proxy_stripe_webhook(request: Request):
|
||||
"""Proxy Stripe webhook requests to tenant service (path: /webhooks/stripe)"""
|
||||
logger.info("Received Stripe webhook at /api/v1/webhooks/stripe")
|
||||
return await _proxy_to_tenant_service(request, "/webhooks/stripe")
|
||||
|
||||
# Generic webhook endpoint
|
||||
@router.post("/api/v1/webhooks/generic")
|
||||
async def proxy_generic_webhook(request: Request):
|
||||
"""Proxy generic webhook requests to tenant service (path: /webhooks/generic)"""
|
||||
return await _proxy_to_tenant_service(request, "/webhooks/generic")
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_tenant_service(request: Request, target_path: str):
|
||||
"""Proxy request to tenant service"""
|
||||
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str, service_url: str):
|
||||
"""Generic proxy function with enhanced error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID, Stripe-Signature",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Use unified HeaderManager for consistent header forwarding
|
||||
headers = header_manager.get_all_headers_for_proxy(request)
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Forwarding webhook request to {url}")
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=30.0,
|
||||
read=60.0,
|
||||
write=30.0,
|
||||
pool=30.0
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_config) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying webhook request to {service_url}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
331
gateway/app/utils/subscription_error_responses.py
Normal file
331
gateway/app/utils/subscription_error_responses.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Enhanced Subscription Error Responses
|
||||
|
||||
Provides detailed, conversion-optimized error responses when users
|
||||
hit subscription tier restrictions (HTTP 402 Payment Required).
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UpgradeBenefit(BaseModel):
|
||||
"""A single benefit of upgrading"""
|
||||
text: str
|
||||
icon: str # Icon name (e.g., 'zap', 'trending-up', 'shield')
|
||||
|
||||
|
||||
class ROIEstimate(BaseModel):
|
||||
"""ROI estimate for upgrade"""
|
||||
monthly_savings_min: int
|
||||
monthly_savings_max: int
|
||||
currency: str = "€"
|
||||
payback_period_days: int
|
||||
|
||||
|
||||
class FeatureRestrictionDetail(BaseModel):
|
||||
"""Detailed error response for feature restrictions"""
|
||||
error: str = "subscription_tier_insufficient"
|
||||
code: str = "SUBSCRIPTION_UPGRADE_REQUIRED"
|
||||
status_code: int = 402
|
||||
message: str
|
||||
details: Dict[str, Any]
|
||||
|
||||
|
||||
# Feature-specific upgrade messages
|
||||
FEATURE_MESSAGES = {
|
||||
'analytics': {
|
||||
'title': 'Unlock Advanced Analytics',
|
||||
'description': 'Get deeper insights into your bakery performance with advanced analytics dashboards.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='90-day forecast horizon (vs 7 days)', icon='calendar'),
|
||||
UpgradeBenefit(text='Weather & traffic integration', icon='cloud'),
|
||||
UpgradeBenefit(text='What-if scenario modeling', icon='trending-up'),
|
||||
UpgradeBenefit(text='Custom reports & dashboards', icon='bar-chart'),
|
||||
UpgradeBenefit(text='Profitability analysis by product', icon='dollar-sign')
|
||||
],
|
||||
'roi': ROIEstimate(
|
||||
monthly_savings_min=800,
|
||||
monthly_savings_max=1200,
|
||||
payback_period_days=7
|
||||
)
|
||||
},
|
||||
'multi_location': {
|
||||
'title': 'Scale to Multiple Locations',
|
||||
'description': 'Manage up to 3 bakery locations with centralized inventory and analytics.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='Up to 3 locations (vs 1)', icon='map-pin'),
|
||||
UpgradeBenefit(text='Inventory transfer between locations', icon='arrow-right'),
|
||||
UpgradeBenefit(text='Location comparison analytics', icon='bar-chart'),
|
||||
UpgradeBenefit(text='Centralized reporting', icon='file-text'),
|
||||
UpgradeBenefit(text='500 products (vs 50)', icon='package')
|
||||
],
|
||||
'roi': ROIEstimate(
|
||||
monthly_savings_min=1000,
|
||||
monthly_savings_max=2000,
|
||||
payback_period_days=10
|
||||
)
|
||||
},
|
||||
'pos_integration': {
|
||||
'title': 'Integrate Your POS System',
|
||||
'description': 'Automatically sync sales data from your point-of-sale system.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='Automatic sales import', icon='refresh-cw'),
|
||||
UpgradeBenefit(text='Real-time inventory sync', icon='zap'),
|
||||
UpgradeBenefit(text='Save 10+ hours/week on data entry', icon='clock'),
|
||||
UpgradeBenefit(text='Eliminate manual errors', icon='check-circle'),
|
||||
UpgradeBenefit(text='Faster, more accurate forecasts', icon='trending-up')
|
||||
],
|
||||
'roi': ROIEstimate(
|
||||
monthly_savings_min=600,
|
||||
monthly_savings_max=1000,
|
||||
payback_period_days=5
|
||||
)
|
||||
},
|
||||
'advanced_forecasting': {
|
||||
'title': 'Unlock Advanced AI Forecasting',
|
||||
'description': 'Get more accurate predictions with weather, traffic, and seasonal patterns.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='Weather-based demand predictions', icon='cloud'),
|
||||
UpgradeBenefit(text='Traffic & event impact analysis', icon='activity'),
|
||||
UpgradeBenefit(text='Seasonal pattern detection', icon='calendar'),
|
||||
UpgradeBenefit(text='15% more accurate forecasts', icon='target'),
|
||||
UpgradeBenefit(text='Reduce waste by 7+ percentage points', icon='trending-down')
|
||||
],
|
||||
'roi': ROIEstimate(
|
||||
monthly_savings_min=800,
|
||||
monthly_savings_max=1500,
|
||||
payback_period_days=7
|
||||
)
|
||||
},
|
||||
'scenario_modeling': {
|
||||
'title': 'Plan with What-If Scenarios',
|
||||
'description': 'Model different business scenarios before making decisions.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='Test menu changes before launch', icon='beaker'),
|
||||
UpgradeBenefit(text='Optimize pricing strategies', icon='dollar-sign'),
|
||||
UpgradeBenefit(text='Plan seasonal inventory', icon='calendar'),
|
||||
UpgradeBenefit(text='Risk assessment tools', icon='shield'),
|
||||
UpgradeBenefit(text='Data-driven decision making', icon='trending-up')
|
||||
],
|
||||
'roi': ROIEstimate(
|
||||
monthly_savings_min=500,
|
||||
monthly_savings_max=1000,
|
||||
payback_period_days=10
|
||||
)
|
||||
},
|
||||
'api_access': {
|
||||
'title': 'Integrate with Your Tools',
|
||||
'description': 'Connect bakery.ai with your existing business systems via API.',
|
||||
'benefits': [
|
||||
UpgradeBenefit(text='Full REST API access', icon='code'),
|
||||
UpgradeBenefit(text='1,000 API calls/hour (vs 100)', icon='zap'),
|
||||
UpgradeBenefit(text='Webhook support for real-time events', icon='bell'),
|
||||
UpgradeBenefit(text='Custom integrations', icon='link'),
|
||||
UpgradeBenefit(text='API documentation & support', icon='book')
|
||||
],
|
||||
'roi': None # ROI varies by use case
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_upgrade_required_response(
|
||||
feature: str,
|
||||
current_tier: str,
|
||||
required_tier: str = 'professional',
|
||||
allowed_tiers: Optional[List[str]] = None,
|
||||
custom_message: Optional[str] = None
|
||||
) -> FeatureRestrictionDetail:
|
||||
"""
|
||||
Create an enhanced 402 error response with upgrade suggestions
|
||||
|
||||
Args:
|
||||
feature: Feature key (e.g., 'analytics', 'multi_location')
|
||||
current_tier: User's current subscription tier
|
||||
required_tier: Minimum tier required for this feature
|
||||
allowed_tiers: List of tiers that have access (defaults to [required_tier, 'enterprise'])
|
||||
custom_message: Optional custom message (overrides default)
|
||||
|
||||
Returns:
|
||||
FeatureRestrictionDetail with upgrade information
|
||||
"""
|
||||
if allowed_tiers is None:
|
||||
allowed_tiers = [required_tier, 'enterprise'] if required_tier != 'enterprise' else ['enterprise']
|
||||
|
||||
# Get feature-specific messaging
|
||||
feature_info = FEATURE_MESSAGES.get(feature, {
|
||||
'title': f'Upgrade to {required_tier.capitalize()}',
|
||||
'description': f'This feature requires a {required_tier.capitalize()} subscription.',
|
||||
'benefits': [],
|
||||
'roi': None
|
||||
})
|
||||
|
||||
# Build detailed response
|
||||
message = custom_message or feature_info['title']
|
||||
|
||||
details = {
|
||||
'required_feature': feature,
|
||||
'minimum_tier': required_tier,
|
||||
'allowed_tiers': allowed_tiers,
|
||||
'current_tier': current_tier,
|
||||
|
||||
# Upgrade messaging
|
||||
'title': feature_info['title'],
|
||||
'description': feature_info['description'],
|
||||
'benefits': [b.dict() for b in feature_info['benefits']],
|
||||
|
||||
# ROI information
|
||||
'roi_estimate': feature_info['roi'].dict() if feature_info['roi'] else None,
|
||||
|
||||
# Call-to-action
|
||||
'upgrade_url': f'/app/settings/subscription?upgrade={required_tier}&from={current_tier}&feature={feature}',
|
||||
'preview_url': f'/app/{feature}?demo=true' if feature in ['analytics'] else None,
|
||||
|
||||
# Suggested tier
|
||||
'suggested_tier': required_tier,
|
||||
'suggested_tier_display': required_tier.capitalize(),
|
||||
|
||||
# Additional context
|
||||
'can_preview': feature in ['analytics'],
|
||||
'has_free_trial': True,
|
||||
'trial_days': 0,
|
||||
|
||||
# Social proof
|
||||
'social_proof': get_social_proof_message(required_tier),
|
||||
|
||||
# Pricing context
|
||||
'pricing_context': get_pricing_context(required_tier)
|
||||
}
|
||||
|
||||
return FeatureRestrictionDetail(
|
||||
message=message,
|
||||
details=details
|
||||
)
|
||||
|
||||
|
||||
def create_quota_exceeded_response(
|
||||
metric: str,
|
||||
current: int,
|
||||
limit: int,
|
||||
current_tier: str,
|
||||
upgrade_tier: str = 'professional',
|
||||
upgrade_limit: Optional[int] = None,
|
||||
reset_at: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an enhanced 429 error response for quota limits
|
||||
|
||||
Args:
|
||||
metric: The quota metric (e.g., 'training_jobs', 'forecasts')
|
||||
current: Current usage
|
||||
limit: Quota limit
|
||||
current_tier: User's current subscription tier
|
||||
upgrade_tier: Suggested upgrade tier
|
||||
upgrade_limit: Limit in upgraded tier (None = unlimited)
|
||||
reset_at: When the quota resets (ISO datetime string)
|
||||
|
||||
Returns:
|
||||
Error response with upgrade suggestions
|
||||
"""
|
||||
metric_labels = {
|
||||
'training_jobs': 'Training Jobs',
|
||||
'forecasts': 'Forecasts',
|
||||
'api_calls': 'API Calls',
|
||||
'products': 'Products',
|
||||
'users': 'Users',
|
||||
'locations': 'Locations'
|
||||
}
|
||||
|
||||
label = metric_labels.get(metric, metric.replace('_', ' ').title())
|
||||
|
||||
return {
|
||||
'error': 'quota_exceeded',
|
||||
'code': 'QUOTA_LIMIT_REACHED',
|
||||
'status_code': 429,
|
||||
'message': f'Daily quota exceeded for {label.lower()}',
|
||||
'details': {
|
||||
'metric': metric,
|
||||
'label': label,
|
||||
'current': current,
|
||||
'limit': limit,
|
||||
'reset_at': reset_at,
|
||||
'quota_type': metric,
|
||||
|
||||
# Upgrade suggestion
|
||||
'can_upgrade': True,
|
||||
'upgrade_tier': upgrade_tier,
|
||||
'upgrade_limit': upgrade_limit,
|
||||
'upgrade_benefit': f'{upgrade_limit}x more capacity' if upgrade_limit and limit else 'Unlimited capacity',
|
||||
|
||||
# Call-to-action
|
||||
'upgrade_url': f'/app/settings/subscription?upgrade={upgrade_tier}&from={current_tier}&reason=quota_exceeded&metric={metric}',
|
||||
|
||||
# ROI context
|
||||
'roi_message': get_quota_roi_message(metric, current_tier, upgrade_tier)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_social_proof_message(tier: str) -> str:
|
||||
"""Get social proof message for a tier"""
|
||||
messages = {
|
||||
'professional': '87% of growing bakeries choose Professional',
|
||||
'enterprise': 'Trusted by multi-location bakery chains across Europe'
|
||||
}
|
||||
return messages.get(tier, '')
|
||||
|
||||
|
||||
def get_pricing_context(tier: str) -> Dict[str, Any]:
|
||||
"""Get pricing context for a tier"""
|
||||
pricing = {
|
||||
'professional': {
|
||||
'monthly_price': 149,
|
||||
'yearly_price': 1490,
|
||||
'per_day_cost': 4.97,
|
||||
'currency': '€',
|
||||
'savings_yearly': 596,
|
||||
'value_message': 'Only €4.97/day for unlimited growth'
|
||||
},
|
||||
'enterprise': {
|
||||
'monthly_price': 499,
|
||||
'yearly_price': 4990,
|
||||
'per_day_cost': 16.63,
|
||||
'currency': '€',
|
||||
'savings_yearly': 1998,
|
||||
'value_message': 'Complete solution for €16.63/day'
|
||||
}
|
||||
}
|
||||
return pricing.get(tier, {})
|
||||
|
||||
|
||||
def get_quota_roi_message(metric: str, current_tier: str, upgrade_tier: str) -> str:
|
||||
"""Get ROI-focused message for quota upgrades"""
|
||||
messages = {
|
||||
'training_jobs': 'More training = better predictions = less waste',
|
||||
'forecasts': 'Run forecasts for all products daily to optimize inventory',
|
||||
'products': 'Expand your menu without limits',
|
||||
'users': 'Give your entire team access to real-time data',
|
||||
'locations': 'Manage all your bakeries from one platform'
|
||||
}
|
||||
return messages.get(metric, 'Unlock more capacity to grow your business')
|
||||
|
||||
|
||||
# Example usage function for gateway middleware
|
||||
def handle_feature_restriction(
|
||||
feature: str,
|
||||
current_tier: str,
|
||||
required_tier: str = 'professional'
|
||||
) -> tuple[int, Dict[str, Any]]:
|
||||
"""
|
||||
Handle feature restriction in gateway middleware
|
||||
|
||||
Returns:
|
||||
(status_code, response_body)
|
||||
"""
|
||||
response = create_upgrade_required_response(
|
||||
feature=feature,
|
||||
current_tier=current_tier,
|
||||
required_tier=required_tier
|
||||
)
|
||||
|
||||
return response.status_code, response.dict()
|
||||
Reference in New Issue
Block a user