# File listing metadata:
#   path:     bakery-ia/gateway/app/routes/auth.py
#   modified: 2026-01-12 14:24:14 +01:00
#
#   length:   310 lines
#   size:     14 KiB
#   language: Python

# ================================================================
# gateway/app/routes/auth.py
# ================================================================
"""
Authentication and User Management Routes for API Gateway
Unified proxy to auth microservice
"""
import logging
import httpx
from fastapi import APIRouter, Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from typing import Dict, Any
from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery
from shared.monitoring.metrics import MetricsCollector
# Module logger and the router that collects the routes declared in this file.
logger = logging.getLogger(__name__)
router = APIRouter()

# Collaborators shared by every request: service lookup and gateway metrics.
service_discovery = ServiceDiscovery()
metrics = MetricsCollector("gateway")

# Auth service base URL; the literal default applies when the setting is falsy.
_DEFAULT_AUTH_SERVICE_URL = "http://auth-service:8000"
AUTH_SERVICE_URL = settings.AUTH_SERVICE_URL or _DEFAULT_AUTH_SERVICE_URL
class AuthProxy:
    """Proxy that relays gateway requests to the authentication service.

    An instance owns one pooled ``httpx.AsyncClient``. ``forward_request`` is
    the entry point; the underscore-prefixed methods are its helpers.
    """

    # Headers that describe a single connection and must not be relayed.
    # "trailer" is the real header name; the legacy spelling "trailers" is
    # kept so nothing that was filtered before starts leaking through.
    _HOP_BY_HOP_HEADERS = frozenset({
        'connection', 'keep-alive', 'proxy-authenticate',
        'proxy-authorization', 'te', 'trailer', 'trailers',
        'transfer-encoding', 'upgrade',
    })

    # Upstream response headers that are not passed back to the client.
    # content-length / content-encoding describe the bytes on the upstream
    # wire, but httpx hands us the already-decoded body, so relaying them
    # would mislabel the payload; the outgoing Response recomputes the length.
    _DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
        'server', 'date', 'content-length', 'content-encoding',
    }

    # CORS headers stamped on proxied responses when CORS is configured.
    _CORS_RESPONSE_HEADERS = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
    }

    # Identity headers that may be back-filled from request.state.
    _INJECTED_IDENTITY_HEADERS = ('x-user-id', 'x-user-email', 'x-user-role')

    def __init__(self):
        """Create the pooled HTTP client used for every upstream call."""
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20)
        )

    async def forward_request(
        self,
        method: str,
        path: str,
        request: Request
    ) -> Response:
        """Forward ``request`` to the auth service and relay its answer.

        Args:
            method: HTTP method to use for the upstream call.
            path: Upstream path, without a leading slash.
            request: Incoming request; its headers, query string and body
                are forwarded.

        Returns:
            A ``Response`` carrying the upstream status, body and filtered
            headers, or a locally built CORS preflight answer for OPTIONS.

        Raises:
            HTTPException: 504 on upstream timeout, 503 when the upstream
                cannot be reached, 500 for any other failure.
        """
        # Preflight requests are answered locally and never reach the auth service.
        if request.method == "OPTIONS":
            return self._build_preflight_response(request)

        try:
            # Resolve the upstream base URL (service discovery, then config).
            auth_url = await self._get_auth_service_url()
            # Tolerate a configured base URL that ends with "/".
            target_url = f"{auth_url.rstrip('/')}/{path}"

            # request.headers is passed as-is so headers added by middleware
            # are seen; request.state is consulted for anything still missing.
            headers = self._prepare_headers(request.headers, request)
            body = await request.body()

            logger.info(f"Forwarding {method} /{path} to auth service")
            response = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
                params=dict(request.query_params)
            )

            metrics.increment_counter("gateway_auth_requests_total")
            metrics.increment_counter(
                "gateway_auth_responses_total",
                labels={"status_code": str(response.status_code)}
            )

            response_headers = self._prepare_response_headers(dict(response.headers))
            return Response(
                content=response.content,
                status_code=response.status_code,
                headers=response_headers,
                media_type=response.headers.get("content-type")
            )
        except httpx.TimeoutException as exc:
            logger.error(f"Timeout forwarding {method} /{path} to auth service")
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "timeout"})
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Authentication service timeout"
            ) from exc
        except httpx.ConnectError as exc:
            logger.error(f"Connection error forwarding {method} /{path} to auth service")
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "connection"})
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Authentication service unavailable"
            ) from exc
        except Exception as e:
            # Boundary handler: keep the traceback in the log, hide it from the client.
            logger.error(f"Error forwarding {method} /{path} to auth service: {e}", exc_info=True)
            metrics.increment_counter("gateway_auth_errors_total", labels={"error": "unknown"})
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal gateway error"
            ) from e

    def _build_preflight_response(self, request: Request) -> Response:
        """Build the 200 answer for a CORS preflight (OPTIONS) request."""
        preflight_headers = {}
        allowed_origin = self._resolve_allowed_origin(
            request.headers.get("origin"), settings.CORS_ORIGINS_LIST
        )
        if allowed_origin:
            preflight_headers["Access-Control-Allow-Origin"] = allowed_origin
            # The answer depends on the caller's Origin, so caches must key on it.
            preflight_headers["Vary"] = "Origin"
        preflight_headers.update({
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
            "Access-Control-Allow-Credentials": "true",
            "Access-Control-Max-Age": "86400"  # Cache preflight for 24 hours
        })
        return Response(status_code=200, headers=preflight_headers)

    @staticmethod
    def _resolve_allowed_origin(request_origin, configured):
        """Pick the single value to send as Access-Control-Allow-Origin.

        The header carries one origin (or ``*``), never a list, so a
        collection of configured origins is matched against the caller's
        ``Origin`` header instead of being sent verbatim.

        Args:
            request_origin: Value of the request's ``Origin`` header, or None.
            configured: ``settings.CORS_ORIGINS_LIST``. NOTE(review): its type
                is not visible from this module; a plain string is passed
                through unchanged, any other iterable is treated as the set
                of allowed origins.

        Returns:
            The header value to send, or None when no origin is allowed.
        """
        if isinstance(configured, str):
            return configured
        allowed = [str(origin) for origin in (configured or [])]
        if request_origin and request_origin in allowed:
            return request_origin
        if "*" in allowed:
            # Credentials are enabled, so echo the caller's origin when known.
            return request_origin or "*"
        return None

    async def _get_auth_service_url(self) -> str:
        """Return the auth service base URL.

        Service discovery is tried first; when it raises or yields nothing
        the module-level ``AUTH_SERVICE_URL`` is used.
        """
        try:
            service_url = await service_discovery.get_service_url("auth-service")
            if service_url:
                return service_url
        except Exception as e:
            # Best effort: a discovery failure must not fail the request.
            logger.warning(f"Service discovery failed: {e}")
        return AUTH_SERVICE_URL

    @staticmethod
    def _headers_to_dict(headers) -> Dict[str, str]:
        """Flatten a headers object into a plain ``{name: value}`` dict.

        Accepts an object exposing raw ``(name, value)`` pairs through
        ``_list`` or ``raw`` (bytes are decoded), or any mapping with
        ``items()``. Repeated header names collapse to the last value.
        """
        # `_list` is preferred because that is where middleware appends
        # headers; `raw` is the fallback for objects without it.
        raw_pairs = getattr(headers, '_list', None)
        if raw_pairs is None:
            raw_pairs = getattr(headers, 'raw', None)
        if raw_pairs is None:
            return dict(headers.items())

        def _text(value):
            return value.decode() if isinstance(value, bytes) else value

        return {_text(name): _text(value) for name, value in raw_pairs}

    def _prepare_headers(self, headers, request=None) -> Dict[str, str]:
        """Build the header dict sent to the auth service.

        Args:
            headers: Incoming headers (object with ``_list``/``raw`` pairs,
                or a mapping).
            request: Optional request; when ``request.state.injected_headers``
                exists, identity headers missing from ``headers`` are
                back-filled from it and ``x-is-demo`` is set for demo sessions.

        Returns:
            Headers without hop-by-hop entries, plus the gateway identifiers.
        """
        all_headers = self._headers_to_dict(headers)

        # Only selected values and the header names are logged; full header
        # values may contain credentials.
        logger.info(f"📤 Forwarding headers to auth service - x_user_id: {all_headers.get('x-user-id', 'MISSING')}, x_is_demo: {all_headers.get('x-is-demo', 'MISSING')}, x_demo_session_id: {all_headers.get('x-demo-session-id', 'MISSING')}, headers: {list(all_headers.keys())}")

        # Back-fill identity headers that middleware stored on request.state.
        state = getattr(request, 'state', None) if request is not None else None
        injected = getattr(state, 'injected_headers', None)
        if injected is not None:
            for name in self._INJECTED_IDENTITY_HEADERS:
                if name not in all_headers and name in injected:
                    all_headers[name] = injected[name]
                    logger.debug(f"Added {name} from request.state")
            if getattr(state, 'is_demo_session', False):
                all_headers['x-is-demo'] = 'true'

        filtered_headers = {
            k: v for k, v in all_headers.items()
            if k.lower() not in self._HOP_BY_HOP_HEADERS
        }

        # Add gateway identifier
        filtered_headers['X-Forwarded-By'] = 'bakery-gateway'
        filtered_headers['X-Gateway-Version'] = '1.0.0'
        return filtered_headers

    def _prepare_response_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
        """Filter upstream response headers before relaying them.

        Drops server-specific, hop-by-hop and body-framing headers. When
        ``settings.CORS_ORIGINS`` is set, the gateway's CORS headers replace
        any upstream header of the same name (compared case-insensitively)
        so the client never receives two conflicting values.

        NOTE(review): the wildcard origin is kept from the original code;
        confirm it is intended for credentialed requests.
        """
        cors_headers = self._CORS_RESPONSE_HEADERS if settings.CORS_ORIGINS else {}
        overridden = {name.lower() for name in cors_headers}
        filtered_headers = {
            k: v for k, v in headers.items()
            if k.lower() not in self._DROPPED_RESPONSE_HEADERS
            and k.lower() not in overridden
        }
        filtered_headers.update(cors_headers)
        return filtered_headers
# Module-level proxy instance used by the route handlers in this file.
# Constructing it creates the pooled httpx.AsyncClient (see AuthProxy.__init__),
# which the shutdown handler closes.
auth_proxy = AuthProxy()
# ================================================================
# HEALTH CHECK for auth service
# ================================================================
# NOTE: this route must be declared before the catch-all below. Routes are
# matched in declaration order and "/{path:path}" also matches "/health" for
# GET, so declaring the catch-all first makes this handler unreachable.
@router.get("/health")
async def auth_service_health():
    """Probe the auth service's ``/health`` endpoint.

    Uses a short-lived client with a 5 second timeout against the configured
    ``AUTH_SERVICE_URL``.

    Returns:
        A dict whose ``status`` is ``"healthy"`` (with ``response_time_ms``)
        when the upstream answers 200, otherwise ``"unhealthy"`` together
        with either the upstream ``status_code`` or the ``error`` text.
        Failures are reported in the payload rather than raised.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{AUTH_SERVICE_URL}/health")
    except Exception as e:
        # Any failure to reach the upstream is reported, not propagated.
        logger.error(f"Auth service health check failed: {e}")
        return {
            "status": "unhealthy",
            "auth_service": "unavailable",
            "error": str(e)
        }

    if response.status_code == 200:
        return {
            "status": "healthy",
            "auth_service": "available",
            "response_time_ms": response.elapsed.total_seconds() * 1000
        }
    return {
        "status": "unhealthy",
        "auth_service": "error",
        "status_code": response.status_code
    }


# ================================================================
# CATCH-ALL ROUTE for all auth and user endpoints
# ================================================================
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_auth_requests(path: str, request: Request):
    """Proxy every other request on this router to the auth service.

    The captured ``path`` is forwarded under the ``api/v1/auth/`` prefix
    using the incoming request's own HTTP method.
    """
    return await auth_proxy.forward_request(request.method, f"api/v1/auth/{path}", request)
# ================================================================
# CLEANUP
# ================================================================
@router.on_event("shutdown")
async def cleanup():
    """Close the shared proxy's HTTP client on shutdown."""
    shared_client = auth_proxy.client
    await shared_client.aclose()