Add role-based filtering and imporve code

This commit is contained in:
Urtzi Alfaro
2025-10-15 16:12:49 +02:00
parent 96ad5c6692
commit 8f9e9a7edc
158 changed files with 11033 additions and 1544 deletions

View File

@@ -13,18 +13,18 @@ logger = logging.getLogger(__name__)
class ServiceDiscovery:
"""Service discovery client"""
def __init__(self):
self.consul_url = settings.CONSUL_URL if hasattr(settings, 'CONSUL_URL') else None
self.service_cache: Dict[str, str] = {}
async def get_service_url(self, service_name: str) -> Optional[str]:
"""Get service URL from service discovery"""
# Return cached URL if available
if service_name in self.service_cache:
return self.service_cache[service_name]
# Try Consul if enabled
if self.consul_url and getattr(settings, 'ENABLE_SERVICE_DISCOVERY', False):
try:
@@ -34,10 +34,10 @@ class ServiceDiscovery:
return url
except Exception as e:
logger.warning(f"Failed to get {service_name} from Consul: {e}")
# Fall back to environment variables
return self._get_from_env(service_name)
async def _get_from_consul(self, service_name: str) -> Optional[str]:
"""Get service URL from Consul"""
try:
@@ -45,7 +45,7 @@ class ServiceDiscovery:
response = await client.get(
f"{self.consul_url}/v1/health/service/{service_name}?passing=true"
)
if response.status_code == 200:
services = response.json()
if services:
@@ -53,13 +53,13 @@ class ServiceDiscovery:
address = service['Service']['Address']
port = service['Service']['Port']
return f"http://{address}:{port}"
except Exception as e:
logger.error(f"Consul query failed: {e}")
return None
def _get_from_env(self, service_name: str) -> Optional[str]:
"""Get service URL from environment variables"""
env_var = f"{service_name.upper().replace('-', '_')}_SERVICE_URL"
return getattr(settings, env_var, None)
return getattr(settings, env_var, None)

View File

@@ -11,11 +11,11 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import httpx
import time
import redis.asyncio as aioredis
from shared.redis_utils import initialize_redis, close_redis, get_redis_client
from typing import Dict, Any
from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from app.middleware.request_id import RequestIDMiddleware
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
@@ -41,9 +41,6 @@ app = FastAPI(
# Initialize metrics collector
metrics_collector = MetricsCollector("gateway")
# Service discovery
service_discovery = ServiceDiscovery()
# Redis client for SSE streaming
redis_client = None
@@ -57,12 +54,13 @@ app.add_middleware(
)
# Custom middleware - Add in REVERSE order (last added = first executed)
# Execution order: DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware) # Executes 5th (outermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 4th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 3rd
app.add_middleware(AuthMiddleware) # Executes 2nd - Checks for demo context
app.add_middleware(DemoMiddleware) # Executes 1st (innermost) - Sets demo user context FIRST
# Execution order: RequestIDMiddleware -> DemoMiddleware -> AuthMiddleware -> SubscriptionMiddleware -> RateLimitMiddleware -> LoggingMiddleware
app.add_middleware(LoggingMiddleware) # Executes 6th (outermost)
app.add_middleware(RateLimitMiddleware, calls_per_minute=300) # Executes 5th
app.add_middleware(SubscriptionMiddleware, tenant_service_url=settings.TENANT_SERVICE_URL) # Executes 4th
app.add_middleware(AuthMiddleware) # Executes 3rd - Checks for demo context
app.add_middleware(DemoMiddleware) # Executes 2nd - Sets demo user context
app.add_middleware(RequestIDMiddleware) # Executes 1st (innermost) - Generates request ID for tracing
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
@@ -79,12 +77,13 @@ app.include_router(demo.router, prefix="/api/v1", tags=["demo"])
async def startup_event():
"""Application startup"""
global redis_client
logger.info("Starting API Gateway")
# Connect to Redis for SSE streaming
# Initialize shared Redis connection
try:
redis_client = aioredis.from_url(settings.REDIS_URL)
await initialize_redis(settings.REDIS_URL, db=0, max_connections=50)
redis_client = await get_redis_client()
logger.info("Connected to Redis for SSE streaming")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
@@ -116,17 +115,14 @@ async def startup_event():
@app.on_event("shutdown")
async def shutdown_event():
"""Application shutdown"""
global redis_client
logger.info("Shutting down API Gateway")
# Close Redis connection
if redis_client:
await redis_client.close()
# Close shared Redis connection
await close_redis()
# Clean up service discovery
# await service_discovery.cleanup()
logger.info("API Gateway shutdown complete")
@app.get("/health")

View File

@@ -0,0 +1,83 @@
"""
Request ID Middleware for distributed tracing
Generates and propagates unique request IDs across all services
"""
import uuid
import structlog
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
logger = structlog.get_logger()
class RequestIDMiddleware(BaseHTTPMiddleware):
"""
Middleware to generate and propagate request IDs for distributed tracing.
Request IDs are:
- Generated if not provided by client
- Logged with every request
- Propagated to all downstream services
- Returned in response headers
"""
def __init__(self, app):
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with request ID tracking"""
# Extract or generate request ID
request_id = request.headers.get("X-Request-ID")
if not request_id:
request_id = str(uuid.uuid4())
# Store in request state for access by routes
request.state.request_id = request_id
# Bind request ID to structured logger context
logger_ctx = logger.bind(request_id=request_id)
# Inject request ID header for downstream services
# This is done by modifying the headers that will be forwarded
request.headers.__dict__["_list"].append((
b"x-request-id", request_id.encode()
))
# Log request start
logger_ctx.info(
"Request started",
method=request.method,
path=request.url.path,
client_ip=request.client.host if request.client else None
)
try:
# Process request
response = await call_next(request)
# Add request ID to response headers
response.headers["X-Request-ID"] = request_id
# Log request completion
logger_ctx.info(
"Request completed",
method=request.method,
path=request.url.path,
status_code=response.status_code
)
return response
except Exception as e:
# Log request failure
logger_ctx.error(
"Request failed",
method=request.method,
path=request.url.path,
error=str(e),
error_type=type(e).__name__
)
raise

View File

@@ -26,13 +26,13 @@ async def proxy_subscription_endpoints(request: Request, tenant_id: str = Path(.
@router.api_route("/subscriptions/plans", methods=["GET", "OPTIONS"])
async def proxy_subscription_plans(request: Request):
"""Proxy subscription plans request to tenant service"""
target_path = "/api/v1/plans"
target_path = "/plans"
return await _proxy_to_tenant_service(request, target_path)
@router.api_route("/plans", methods=["GET", "OPTIONS"])
async def proxy_plans(request: Request):
"""Proxy plans request to tenant service"""
target_path = "/api/v1/plans"
target_path = "/plans"
return await _proxy_to_tenant_service(request, target_path)
# ================================================================