Files
bakery-ia/gateway/app/middleware/request_id.py
2026-01-12 22:15:11 +01:00

84 lines
2.5 KiB
Python

"""
Request ID Middleware for distributed tracing
Generates and propagates unique request IDs across all services
"""
import re
import uuid

import structlog
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from app.core.header_manager import header_manager
# Module-level structlog logger; RequestIDMiddleware.dispatch binds the
# per-request request_id onto it before emitting its lifecycle log entries.
logger = structlog.get_logger()
# Longest client-supplied request ID that is reused as-is; longer values are
# replaced by a freshly generated UUID4.
_MAX_REQUEST_ID_LENGTH = 128

# Conservative alphabet for client-supplied request IDs. UUIDs, ULIDs and
# hex trace IDs all match. Restricting it keeps untrusted header values from
# smuggling whitespace, control characters or separators into log entries
# and into the headers forwarded to downstream services.
_REQUEST_ID_PATTERN = re.compile(r"[A-Za-z0-9._:\-]+")


class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware to generate and propagate request IDs for distributed tracing.

    Request IDs are:
    - Taken from the client's ``X-Request-ID`` header when it is present and
      well-formed, otherwise generated as a UUID4
    - Stored on ``request.state.request_id`` for access by routes
    - Bound to the start / completion / failure log entries of each request
    - Injected for downstream services via ``header_manager``
    - Returned in the ``X-Request-ID`` response header
    """

    @staticmethod
    def _resolve_request_id(candidate: str) -> str:
        """Return *candidate* if it is an acceptable request ID, else a new UUID4.

        The incoming header is client-controlled, so it is only reused when,
        after stripping surrounding whitespace, it is non-empty, at most
        ``_MAX_REQUEST_ID_LENGTH`` characters long, and consists solely of
        characters allowed by ``_REQUEST_ID_PATTERN``.
        """
        candidate = candidate.strip()
        if (
            candidate
            and len(candidate) <= _MAX_REQUEST_ID_LENGTH
            and _REQUEST_ID_PATTERN.fullmatch(candidate)
        ):
            return candidate
        return str(uuid.uuid4())

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request with request ID tracking.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next
                middleware / route and returns its response.

        Returns:
            The downstream response, with the ``X-Request-ID`` header set.

        Raises:
            Any exception raised by ``call_next`` is logged (with traceback)
            and re-raised unchanged.
        """
        # Reuse a well-formed client-supplied ID, otherwise generate one.
        request_id = self._resolve_request_id(
            request.headers.get("X-Request-ID", "")
        )

        # Store in request state for access by routes.
        request.state.request_id = request_id

        # Bind the request ID so every log entry below carries it.
        logger_ctx = logger.bind(request_id=request_id)

        # Inject request ID header for downstream services using HeaderManager.
        # Note: This runs early in middleware chain, so we use
        # add_header_for_middleware.
        header_manager.add_header_for_middleware(request, "x-request-id", request_id)

        logger_ctx.info(
            "Request started",
            method=request.method,
            path=request.url.path,
            client_ip=request.client.host if request.client else None,
        )

        try:
            response = await call_next(request)
        except Exception as e:
            # exc_info=True keeps the traceback in the log record; the
            # exception is re-raised so outer handlers still see it.
            logger_ctx.error(
                "Request failed",
                method=request.method,
                path=request.url.path,
                error=str(e),
                error_type=type(e).__name__,
                exc_info=True,
            )
            raise

        # Echo the request ID back to the client.
        response.headers["X-Request-ID"] = request_id

        logger_ctx.info(
            "Request completed",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
        )
        return response