Initial commit - production deployment
This commit is contained in:
0
gateway/app/core/__init__.py
Normal file
0
gateway/app/core/__init__.py
Normal file
46
gateway/app/core/config.py
Normal file
46
gateway/app/core/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# ================================================================
|
||||
# GATEWAY SERVICE CONFIGURATION
|
||||
# gateway/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Gateway service configuration
|
||||
Central API Gateway for all microservices
|
||||
"""
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
class GatewaySettings(BaseServiceSettings):
|
||||
"""Gateway-specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Bakery Forecasting Gateway"
|
||||
SERVICE_NAME: str = "gateway"
|
||||
DESCRIPTION: str = "API Gateway for Bakery Forecasting Platform"
|
||||
|
||||
# Gateway-specific Redis database
|
||||
REDIS_DB: int = 6
|
||||
|
||||
# Service Discovery
|
||||
CONSUL_URL: str = os.getenv("CONSUL_URL", "http://consul:8500")
|
||||
ENABLE_SERVICE_DISCOVERY: bool = os.getenv("ENABLE_SERVICE_DISCOVERY", "false").lower() == "true"
|
||||
|
||||
# Load Balancing
|
||||
ENABLE_LOAD_BALANCING: bool = os.getenv("ENABLE_LOAD_BALANCING", "true").lower() == "true"
|
||||
LOAD_BALANCER_ALGORITHM: str = os.getenv("LOAD_BALANCER_ALGORITHM", "round_robin")
|
||||
|
||||
# Circuit Breaker
|
||||
CIRCUIT_BREAKER_ENABLED: bool = os.getenv("CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
|
||||
CIRCUIT_BREAKER_FAILURE_THRESHOLD: int = int(os.getenv("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
|
||||
CIRCUIT_BREAKER_RECOVERY_TIMEOUT: int = int(os.getenv("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
|
||||
|
||||
# Request/Response Settings
|
||||
MAX_REQUEST_SIZE: int = int(os.getenv("MAX_REQUEST_SIZE", "10485760")) # 10MB
|
||||
REQUEST_TIMEOUT: int = int(os.getenv("REQUEST_TIMEOUT", "30"))
|
||||
|
||||
# Gateway doesn't need a database
|
||||
DATABASE_URL: str = ""
|
||||
|
||||
settings = GatewaySettings()
|
||||
346
gateway/app/core/header_manager.py
Normal file
346
gateway/app/core/header_manager.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Unified Header Management System for API Gateway
|
||||
Centralized header injection, forwarding, and validation
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class HeaderManager:
|
||||
"""
|
||||
Centralized header management for consistent header handling across gateway
|
||||
"""
|
||||
|
||||
# Standard header names (lowercase for consistency)
|
||||
STANDARD_HEADERS = {
|
||||
'user_id': 'x-user-id',
|
||||
'user_email': 'x-user-email',
|
||||
'user_role': 'x-user-role',
|
||||
'user_type': 'x-user-type',
|
||||
'service_name': 'x-service-name',
|
||||
'tenant_id': 'x-tenant-id',
|
||||
'subscription_tier': 'x-subscription-tier',
|
||||
'subscription_status': 'x-subscription-status',
|
||||
'is_demo': 'x-is-demo',
|
||||
'demo_session_id': 'x-demo-session-id',
|
||||
'demo_account_type': 'x-demo-account-type',
|
||||
'tenant_access_type': 'x-tenant-access-type',
|
||||
'can_view_children': 'x-can-view-children',
|
||||
'parent_tenant_id': 'x-parent-tenant-id',
|
||||
'forwarded_by': 'x-forwarded-by',
|
||||
'request_id': 'x-request-id'
|
||||
}
|
||||
|
||||
# Headers that should be sanitized/removed from incoming requests
|
||||
SANITIZED_HEADERS = [
|
||||
'x-subscription-',
|
||||
'x-user-',
|
||||
'x-tenant-',
|
||||
'x-demo-',
|
||||
'x-forwarded-by'
|
||||
]
|
||||
|
||||
# Headers that should be forwarded to downstream services
|
||||
FORWARDABLE_HEADERS = [
|
||||
'authorization',
|
||||
'content-type',
|
||||
'accept',
|
||||
'accept-language',
|
||||
'user-agent',
|
||||
'x-internal-service', # Required for internal service-to-service ML/alert triggers
|
||||
'stripe-signature' # Required for Stripe webhook signature verification
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize header manager"""
|
||||
if not self._initialized:
|
||||
logger.info("HeaderManager initialized")
|
||||
self._initialized = True
|
||||
|
||||
def sanitize_incoming_headers(self, request: Request) -> None:
|
||||
"""
|
||||
Remove sensitive headers from incoming request to prevent spoofing
|
||||
"""
|
||||
if not hasattr(request.headers, '_list'):
|
||||
return
|
||||
|
||||
# Filter out headers that start with sanitized prefixes
|
||||
sanitized_headers = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if not any(k.decode().lower().startswith(prefix.lower())
|
||||
for prefix in self.SANITIZED_HEADERS)
|
||||
]
|
||||
|
||||
request.headers.__dict__["_list"] = sanitized_headers
|
||||
logger.debug("Sanitized incoming headers")
|
||||
|
||||
def inject_context_headers(self, request: Request, user_context: Dict[str, Any],
|
||||
tenant_id: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Inject standardized context headers into request
|
||||
Returns dict of injected headers for reference
|
||||
"""
|
||||
injected_headers = {}
|
||||
|
||||
# Ensure headers list exists
|
||||
if not hasattr(request.headers, '_list'):
|
||||
request.headers.__dict__["_list"] = []
|
||||
|
||||
# Store headers in request.state for cross-middleware access
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# User context headers
|
||||
if user_context.get('user_id'):
|
||||
header_name = self.STANDARD_HEADERS['user_id']
|
||||
header_value = str(user_context['user_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('email'):
|
||||
header_name = self.STANDARD_HEADERS['user_email']
|
||||
header_value = str(user_context['email'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('role'):
|
||||
header_name = self.STANDARD_HEADERS['user_role']
|
||||
header_value = str(user_context['role'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# User type (service vs regular user)
|
||||
if user_context.get('type'):
|
||||
header_name = self.STANDARD_HEADERS['user_type']
|
||||
header_value = str(user_context['type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Service name for service tokens
|
||||
if user_context.get('service'):
|
||||
header_name = self.STANDARD_HEADERS['service_name']
|
||||
header_value = str(user_context['service'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Tenant context
|
||||
if tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['tenant_id']
|
||||
header_value = str(tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Subscription context
|
||||
if user_context.get('subscription_tier'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_tier']
|
||||
header_value = str(user_context['subscription_tier'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('subscription_status'):
|
||||
header_name = self.STANDARD_HEADERS['subscription_status']
|
||||
header_value = str(user_context['subscription_status'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Demo session context
|
||||
is_demo = user_context.get('is_demo', False)
|
||||
if is_demo:
|
||||
header_name = self.STANDARD_HEADERS['is_demo']
|
||||
header_value = "true"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_session_id'):
|
||||
header_name = self.STANDARD_HEADERS['demo_session_id']
|
||||
header_value = str(user_context['demo_session_id'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
if user_context.get('demo_account_type'):
|
||||
header_name = self.STANDARD_HEADERS['demo_account_type']
|
||||
header_value = str(user_context['demo_account_type'])
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Hierarchical access context
|
||||
if tenant_id:
|
||||
tenant_access_type = getattr(request.state, 'tenant_access_type', 'direct')
|
||||
can_view_children = getattr(request.state, 'can_view_children', False)
|
||||
|
||||
header_name = self.STANDARD_HEADERS['tenant_access_type']
|
||||
header_value = str(tenant_access_type)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
header_name = self.STANDARD_HEADERS['can_view_children']
|
||||
header_value = str(can_view_children).lower()
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Parent tenant ID if hierarchical access
|
||||
parent_tenant_id = getattr(request.state, 'parent_tenant_id', None)
|
||||
if parent_tenant_id:
|
||||
header_name = self.STANDARD_HEADERS['parent_tenant_id']
|
||||
header_value = str(parent_tenant_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Gateway identification
|
||||
header_name = self.STANDARD_HEADERS['forwarded_by']
|
||||
header_value = "bakery-gateway"
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Request ID if available
|
||||
request_id = getattr(request.state, 'request_id', None)
|
||||
if request_id:
|
||||
header_name = self.STANDARD_HEADERS['request_id']
|
||||
header_value = str(request_id)
|
||||
self._add_header(request, header_name, header_value)
|
||||
injected_headers[header_name] = header_value
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
logger.info("🔧 Injected context headers",
|
||||
user_id=user_context.get('user_id'),
|
||||
user_type=user_context.get('type', ''),
|
||||
service_name=user_context.get('service', ''),
|
||||
role=user_context.get('role', ''),
|
||||
tenant_id=tenant_id,
|
||||
is_demo=is_demo,
|
||||
demo_session_id=user_context.get('demo_session_id', ''),
|
||||
path=request.url.path)
|
||||
|
||||
return injected_headers
|
||||
|
||||
def _add_header(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Safely add header to request
|
||||
"""
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name}: {e}")
|
||||
|
||||
def get_forwardable_headers(self, request: Request) -> Dict[str, str]:
|
||||
"""
|
||||
Get headers that should be forwarded to downstream services
|
||||
Includes both original request headers and injected context headers
|
||||
"""
|
||||
forwardable_headers = {}
|
||||
|
||||
# Add forwardable original headers
|
||||
for header_name in self.FORWARDABLE_HEADERS:
|
||||
header_value = request.headers.get(header_name)
|
||||
if header_value:
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add injected context headers from request.state
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
for header_name, header_value in request.state.injected_headers.items():
|
||||
forwardable_headers[header_name] = header_value
|
||||
|
||||
# Add authorization header if present
|
||||
auth_header = request.headers.get('authorization')
|
||||
if auth_header:
|
||||
forwardable_headers['authorization'] = auth_header
|
||||
|
||||
return forwardable_headers
|
||||
|
||||
def get_all_headers_for_proxy(self, request: Request,
|
||||
additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get complete set of headers for proxying to downstream services
|
||||
"""
|
||||
headers = self.get_forwardable_headers(request)
|
||||
|
||||
# Add any additional headers
|
||||
if additional_headers:
|
||||
headers.update(additional_headers)
|
||||
|
||||
# Remove host header as it will be set by httpx
|
||||
headers.pop('host', None)
|
||||
|
||||
return headers
|
||||
|
||||
def validate_required_headers(self, request: Request, required_headers: List[str]) -> bool:
|
||||
"""
|
||||
Validate that required headers are present
|
||||
"""
|
||||
missing_headers = []
|
||||
|
||||
for header_name in required_headers:
|
||||
# Check in injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
continue
|
||||
|
||||
# Check in request headers
|
||||
if request.headers.get(header_name):
|
||||
continue
|
||||
|
||||
missing_headers.append(header_name)
|
||||
|
||||
if missing_headers:
|
||||
logger.warning(f"Missing required headers: {missing_headers}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_header_value(self, request: Request, header_name: str,
|
||||
default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Get header value from either injected headers or request headers
|
||||
"""
|
||||
# Check injected headers first
|
||||
if hasattr(request.state, 'injected_headers'):
|
||||
if header_name in request.state.injected_headers:
|
||||
return request.state.injected_headers[header_name]
|
||||
|
||||
# Check request headers
|
||||
return request.headers.get(header_name, default)
|
||||
|
||||
def add_header_for_middleware(self, request: Request, header_name: str, header_value: str) -> None:
|
||||
"""
|
||||
Allow middleware to add headers to the unified header system
|
||||
This ensures all headers are available for proxying
|
||||
"""
|
||||
# Ensure injected_headers exists
|
||||
if not hasattr(request.state, 'injected_headers'):
|
||||
request.state.injected_headers = {}
|
||||
|
||||
# Add header to injected_headers
|
||||
request.state.injected_headers[header_name] = header_value
|
||||
|
||||
# Also add to actual request headers for compatibility
|
||||
try:
|
||||
request.headers.__dict__["_list"].append((header_name.encode(), header_value.encode()))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add header {header_name} to request headers: {e}")
|
||||
|
||||
logger.debug(f"Middleware added header: {header_name} = {header_value}")
|
||||
|
||||
|
||||
# Global instance for easy access
|
||||
header_manager = HeaderManager()
|
||||
65
gateway/app/core/service_discovery.py
Normal file
65
gateway/app/core/service_discovery.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Service discovery for API Gateway
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Optional, Dict
|
||||
import json
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ServiceDiscovery:
|
||||
"""Service discovery client"""
|
||||
|
||||
def __init__(self):
|
||||
self.consul_url = settings.CONSUL_URL if hasattr(settings, 'CONSUL_URL') else None
|
||||
self.service_cache: Dict[str, str] = {}
|
||||
|
||||
async def get_service_url(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from service discovery"""
|
||||
|
||||
# Return cached URL if available
|
||||
if service_name in self.service_cache:
|
||||
return self.service_cache[service_name]
|
||||
|
||||
# Try Consul if enabled
|
||||
if self.consul_url and getattr(settings, 'ENABLE_SERVICE_DISCOVERY', False):
|
||||
try:
|
||||
url = await self._get_from_consul(service_name)
|
||||
if url:
|
||||
self.service_cache[service_name] = url
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get {service_name} from Consul: {e}")
|
||||
|
||||
# Fall back to environment variables
|
||||
return self._get_from_env(service_name)
|
||||
|
||||
async def _get_from_consul(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from Consul"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(
|
||||
f"{self.consul_url}/v1/health/service/{service_name}?passing=true"
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
services = response.json()
|
||||
if services:
|
||||
service = services[0]
|
||||
address = service['Service']['Address']
|
||||
port = service['Service']['Port']
|
||||
return f"http://{address}:{port}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Consul query failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _get_from_env(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from environment variables"""
|
||||
env_var = f"{service_name.upper().replace('-', '_')}_SERVICE_URL"
|
||||
return getattr(settings, env_var, None)
|
||||
Reference in New Issue
Block a user