Add role-based filtering and imporve code
This commit is contained in:
@@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ServiceDiscovery:
|
||||
"""Service discovery client"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.consul_url = settings.CONSUL_URL if hasattr(settings, 'CONSUL_URL') else None
|
||||
self.service_cache: Dict[str, str] = {}
|
||||
|
||||
|
||||
async def get_service_url(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from service discovery"""
|
||||
|
||||
|
||||
# Return cached URL if available
|
||||
if service_name in self.service_cache:
|
||||
return self.service_cache[service_name]
|
||||
|
||||
|
||||
# Try Consul if enabled
|
||||
if self.consul_url and getattr(settings, 'ENABLE_SERVICE_DISCOVERY', False):
|
||||
try:
|
||||
@@ -34,10 +34,10 @@ class ServiceDiscovery:
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get {service_name} from Consul: {e}")
|
||||
|
||||
|
||||
# Fall back to environment variables
|
||||
return self._get_from_env(service_name)
|
||||
|
||||
|
||||
async def _get_from_consul(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from Consul"""
|
||||
try:
|
||||
@@ -45,7 +45,7 @@ class ServiceDiscovery:
|
||||
response = await client.get(
|
||||
f"{self.consul_url}/v1/health/service/{service_name}?passing=true"
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
services = response.json()
|
||||
if services:
|
||||
@@ -53,13 +53,13 @@ class ServiceDiscovery:
|
||||
address = service['Service']['Address']
|
||||
port = service['Service']['Port']
|
||||
return f"http://{address}:{port}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Consul query failed: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_from_env(self, service_name: str) -> Optional[str]:
|
||||
"""Get service URL from environment variables"""
|
||||
env_var = f"{service_name.upper().replace('-', '_')}_SERVICE_URL"
|
||||
return getattr(settings, env_var, None)
|
||||
return getattr(settings, env_var, None)
|
||||
|
||||
@@ -11,11 +11,11 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import httpx
|
||||
import time
|
||||
import redis.asyncio as aioredis
|
||||
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.service_discovery import ServiceDiscovery
|
||||
from app.middleware.request_id import RequestIDMiddleware
|
||||
from app.middleware.auth import AuthMiddleware
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
@@ -41,9 +41,6 @@ app = FastAPI(
|
||||
# Initialize metrics collector
|
||||
metrics_collector = MetricsCollector("gateway")
|
||||
|
||||
# Service discovery
|
||||
service_discovery = ServiceDiscovery()
|
||||
|
||||
# Redis client for SSE streaming
|
||||
redis_client = None
|
||||
|
||||
@@ -57,12 +54,13 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# Custom middleware - Add in REVERSE order (last added = first executed)
|
||||
# Execution order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
|
||||
app.add_middleware(LoggingMiddleware) # Executes 5th (outermost)
|
||||
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 4th
|
||||
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 3rd
|
||||
app.add_middleware(AuthMiddleware) # Executes 2nd - Checks for demo context
|
||||
app.add_middleware(DemoMiddleware) # Executes 1st (innermost) - Sets demo user context FIRST
|
||||
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
|
||||
app.add_middleware(LoggingMiddleware) # Executes 6th (outermost)
|
||||
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 5th
|
||||
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 4th
|
||||
app.add_middleware(AuthMiddleware) # Executes 3rd - Checks for demo context
|
||||
app.add_middleware(DemoMiddleware) # Executes 2nd - Sets demo user context
|
||||
app.add_middleware(RequestIDMiddleware) # Executes 1st (innermost) - Generates request ID for tracing
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
||||
@@ -79,12 +77,13 @@ app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
|
||||
async def startup_event():
|
||||
"""Application startup"""
|
||||
global redis_client
|
||||
|
||||
|
||||
logger.info("Starting API Gateway")
|
||||
|
||||
# Connect to Redis for SSE streaming
|
||||
|
||||
# Initialize shared Redis connection
|
||||
try:
|
||||
redis_client = aioredis.from_url(settings.REDIS_URL)
|
||||
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
|
||||
redis_client = await get_redis_client()
|
||||
logger.info("Connected to Redis for SSE streaming")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
@@ -116,17 +115,14 @@ async def startup_event():
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Application shutdown"""
|
||||
global redis_client
|
||||
|
||||
logger.info("Shutting down API Gateway")
|
||||
|
||||
# Close Redis connection
|
||||
if redis_client:
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
# Close shared Redis connection
|
||||
await close_redis()
|
||||
|
||||
# Clean up service discovery
|
||||
# await service_discovery.cleanup()
|
||||
|
||||
|
||||
logger.info("API Gateway shutdown complete")
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
83
gateway/app/middleware/request_id.py
Normal file
83
gateway/app/middleware/request_id.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Request ID Middleware for distributed tracing
|
||||
Generates and propagates unique request IDs across all services
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import structlog
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to generate and propagate request IDs for distributed tracing.
|
||||
|
||||
Request IDs are:
|
||||
- Generated if not provided by client
|
||||
- Logged with every request
|
||||
- Propagated to all downstream services
|
||||
- Returned in response headers
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with request ID tracking"""
|
||||
|
||||
# Extract or generate request ID
|
||||
request_id = request.headers.get("X-Request-ID")
|
||||
if not request_id:
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Store in request state for access by routes
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Bind request ID to structured logger context
|
||||
logger_ctx = logger.bind(request_id=request_id)
|
||||
|
||||
# Inject request ID header for downstream services
|
||||
# This is done by modifying the headers that will be forwarded
|
||||
request.headers.__dict__["_list"].append((
|
||||
b"x-request-id", request_id.encode()
|
||||
))
|
||||
|
||||
# Log request start
|
||||
logger_ctx.info(
|
||||
"Request started",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
client_ip=request.client.host if request.client else None
|
||||
)
|
||||
|
||||
try:
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add request ID to response headers
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
# Log request completion
|
||||
logger_ctx.info(
|
||||
"Request completed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Log request failure
|
||||
logger_ctx.error(
|
||||
"Request failed",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
raise
|
||||
@@ -26,13 +26,13 @@ async def proxy_subscription_endpoints(request: Request, tenant_id: str = Path(.
|
||||
@router.api_route("/subscriptions/plans", methods=["GET", "OPTIONS"])
|
||||
async def proxy_subscription_plans(request: Request):
|
||||
"""Proxy subscription plans request to tenant service"""
|
||||
target_path = "/api/v1/plans"
|
||||
target_path = "/plans"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
@router.api_route("/plans", methods=["GET", "OPTIONS"])
|
||||
async def proxy_plans(request: Request):
|
||||
"""Proxy plans request to tenant service"""
|
||||
target_path = "/api/v1/plans"
|
||||
target_path = "/plans"
|
||||
return await _proxy_to_tenant_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
|
||||
Reference in New Issue
Block a user