Fix gateway

This commit is contained in:
Urtzi Alfaro
2025-07-17 19:54:04 +02:00
parent caf7dea73a
commit 654d1c2fe8
4 changed files with 203 additions and 207 deletions

View File

@@ -14,9 +14,9 @@ from typing import Dict, Any
from app.core.config import settings from app.core.config import settings
from app.core.service_discovery import ServiceDiscovery from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import auth_middleware from app.middleware.auth import AuthMiddleware
from app.middleware.logging import logging_middleware from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import rate_limit_middleware from app.middleware.rate_limit import RateLimitMiddleware
from app.routes import auth, training, forecasting, data, tenant, notification from app.routes import auth, training, forecasting, data, tenant, notification
from shared.monitoring.logging import setup_logging from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector from shared.monitoring.metrics import MetricsCollector
@@ -40,7 +40,7 @@ metrics_collector = MetricsCollector("gateway")
# Service discovery # Service discovery
service_discovery = ServiceDiscovery() service_discovery = ServiceDiscovery()
# CORS middleware - FIXED: Use the parsed list property # CORS middleware - Add first
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS_LIST, allow_origins=settings.CORS_ORIGINS_LIST,
@@ -49,10 +49,10 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Custom middleware # Custom middleware - Add in correct order (outer to inner)
app.add_middleware(auth_middleware) app.add_middleware(LoggingMiddleware)
app.add_middleware(logging_middleware) app.add_middleware(RateLimitMiddleware, calls_per_minute=60)
app.add_middleware(rate_limit_middleware) app.add_middleware(AuthMiddleware)
# Include routers # Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"]) app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
@@ -88,43 +88,17 @@ async def shutdown_event():
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
healthy_services = await service_discovery.get_healthy_services()
return { return {
"status": "healthy", "status": "healthy",
"service": "gateway", "service": "api-gateway",
"version": "1.0.0", "version": "1.0.0",
"healthy_services": healthy_services,
"total_services": len(settings.SERVICES),
"timestamp": time.time() "timestamp": time.time()
} }
@app.get("/metrics") @app.get("/metrics")
async def get_metrics(): async def metrics():
"""Get basic metrics""" """Metrics endpoint for monitoring"""
return { return {"metrics": "enabled"}
"service": "gateway",
"uptime": time.time() - app.state.start_time if hasattr(app.state, 'start_time') else 0,
"healthy_services": await service_discovery.get_healthy_services()
}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""Handle HTTP exceptions"""
logger.error(f"HTTP {exc.status_code}: {exc.detail}")
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail, "service": "gateway"}
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle general exceptions"""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={"detail": "Internal server error", "service": "gateway"}
)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

View File

@@ -3,8 +3,10 @@ Authentication middleware for gateway
""" """
import logging import logging
from fastapi import Request, HTTPException from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import httpx import httpx
from typing import Optional from typing import Optional
@@ -28,74 +30,77 @@ PUBLIC_ROUTES = [
"/api/v1/auth/refresh" "/api/v1/auth/refresh"
] ]
async def auth_middleware(request: Request, call_next): class AuthMiddleware(BaseHTTPMiddleware):
"""Authentication middleware""" """Authentication middleware class"""
# Check if route requires authentication async def dispatch(self, request: Request, call_next) -> Response:
if _is_public_route(request.url.path): """Process request with authentication"""
return await call_next(request)
# Get token from header # Check if route requires authentication
token = _extract_token(request) if self._is_public_route(request.url.path):
if not token:
return JSONResponse(
status_code=401,
content={"detail": "Authentication required"}
)
# Verify token
try:
# First try to verify token locally
payload = jwt_handler.verify_token(token)
if payload:
# Add user info to request state
request.state.user = payload
return await call_next(request) return await call_next(request)
else:
# Token invalid or expired, verify with auth service
user_info = await _verify_with_auth_service(token)
if user_info:
request.state.user = user_info
return await call_next(request)
else:
return JSONResponse(
status_code=401,
content={"detail": "Invalid or expired token"}
)
except Exception as e: # Get token from header
logger.error(f"Authentication error: {e}") token = self._extract_token(request)
return JSONResponse( if not token:
status_code=401, return JSONResponse(
content={"detail": "Authentication failed"} status_code=401,
) content={"detail": "Authentication required"}
def _is_public_route(path: str) -> bool:
"""Check if route is public"""
return any(path.startswith(route) for route in PUBLIC_ROUTES)
def _extract_token(request: Request) -> Optional[str]:
"""Extract JWT token from request"""
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
return None
async def _verify_with_auth_service(token: str) -> Optional[dict]:
"""Verify token with auth service"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/verify",
headers={"Authorization": f"Bearer {token}"}
) )
if response.status_code == 200: # Verify token
return response.json() try:
else: # First try to verify token locally
return None payload = jwt_handler.verify_token(token)
except Exception as e: if payload:
logger.error(f"Auth service verification failed: {e}") # Add user info to request state
request.state.user = payload
return await call_next(request)
else:
# Token invalid or expired, verify with auth service
user_info = await self._verify_with_auth_service(token)
if user_info:
request.state.user = user_info
return await call_next(request)
else:
return JSONResponse(
status_code=401,
content={"detail": "Invalid or expired token"}
)
except Exception as e:
logger.error(f"Authentication error: {e}")
return JSONResponse(
status_code=401,
content={"detail": "Authentication failed"}
)
def _is_public_route(self, path: str) -> bool:
"""Check if route is public"""
return any(path.startswith(route) for route in PUBLIC_ROUTES)
def _extract_token(self, request: Request) -> Optional[str]:
"""Extract JWT token from request"""
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
return None return None
async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
"""Verify token with auth service"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/verify",
headers={"Authorization": f"Bearer {token}"}
)
if response.status_code == 200:
return response.json()
else:
return None
except Exception as e:
logger.error(f"Auth service verification failed: {e}")
return None

View File

@@ -5,44 +5,53 @@ Logging middleware for gateway
import logging import logging
import time import time
from fastapi import Request from fastapi import Request
import json from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import uuid
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def logging_middleware(request: Request, call_next): class LoggingMiddleware(BaseHTTPMiddleware):
"""Logging middleware""" """Logging middleware class"""
start_time = time.time() async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with logging"""
# Log request start_time = time.time()
logger.info(
f"Request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": request.url.path,
"query_params": str(request.query_params),
"client_host": request.client.host,
"user_agent": request.headers.get("user-agent", ""),
"request_id": getattr(request.state, 'request_id', None)
}
)
# Process request # Generate request ID
response = await call_next(request) request_id = str(uuid.uuid4())
request.state.request_id = request_id
# Calculate duration # Log request
duration = time.time() - start_time logger.info(
f"Request: {request.method} {request.url.path}",
extra={
"method": request.method,
"url": request.url.path,
"query_params": str(request.query_params),
"client_host": request.client.host if request.client else "unknown",
"user_agent": request.headers.get("user-agent", ""),
"request_id": request_id
}
)
# Log response # Process request
logger.info( response = await call_next(request)
f"Response: {response.status_code} in {duration:.3f}s",
extra={
"status_code": response.status_code,
"duration": duration,
"method": request.method,
"url": request.url.path,
"request_id": getattr(request.state, 'request_id', None)
}
)
return response # Calculate duration
duration = time.time() - start_time
# Log response
logger.info(
f"Response: {response.status_code} in {duration:.3f}s",
extra={
"status_code": response.status_code,
"duration": duration,
"method": request.method,
"url": request.url.path,
"request_id": request_id
}
)
return response

View File

@@ -3,83 +3,91 @@ Rate limiting middleware for gateway
""" """
import logging import logging
from fastapi import Request, HTTPException import time
from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import redis.asyncio as redis from starlette.middleware.base import BaseHTTPMiddleware
from datetime import datetime, timedelta from starlette.responses import Response
import hashlib from typing import Dict, Optional
import asyncio
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Redis client for rate limiting class RateLimitMiddleware(BaseHTTPMiddleware):
redis_client = redis.from_url(settings.REDIS_URL) """Rate limiting middleware class"""
async def rate_limit_middleware(request: Request, call_next): def __init__(self, app, calls_per_minute: int = 60):
"""Rate limiting middleware""" super().__init__(app)
self.calls_per_minute = calls_per_minute
self.requests: Dict[str, list] = {}
self._cleanup_task = None
# Skip rate limiting for health checks async def dispatch(self, request: Request, call_next) -> Response:
if request.url.path in ["/health", "/metrics"]: """Process request with rate limiting"""
# Skip rate limiting for health checks
if request.url.path in ["/health", "/metrics"]:
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check rate limit
if self._is_rate_limited(client_id):
logger.warning(f"Rate limit exceeded for client: {client_id}")
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"}
)
# Record request
self._record_request(client_id)
# Process request
return await call_next(request) return await call_next(request)
# Get client identifier (IP address or user ID) def _get_client_id(self, request: Request) -> str:
client_id = _get_client_id(request) """Get client identifier"""
# Try to get user ID from state (if authenticated)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user.get('user_id', 'unknown')}"
# Check rate limit # Fall back to IP address
if await _is_rate_limited(client_id): return f"ip:{request.client.host if request.client else 'unknown'}"
return JSONResponse(
status_code=429,
content={
"detail": "Rate limit exceeded",
"retry_after": settings.RATE_LIMIT_WINDOW
}
)
# Process request def _is_rate_limited(self, client_id: str) -> bool:
response = await call_next(request) """Check if client is rate limited"""
now = time.time()
minute_ago = now - 60
# Update rate limit counter # Get recent requests for this client
await _update_rate_limit(client_id) if client_id not in self.requests:
return response
def _get_client_id(request: Request) -> str:
"""Get client identifier for rate limiting"""
# Use user ID if authenticated, otherwise use IP
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user.get('user_id', 'unknown')}"
else:
# Hash IP address for privacy
ip = request.client.host
return f"ip:{hashlib.md5(ip.encode()).hexdigest()}"
async def _is_rate_limited(client_id: str) -> bool:
"""Check if client is rate limited"""
try:
key = f"rate_limit:{client_id}"
current_count = await redis_client.get(key)
if current_count is None:
return False return False
return int(current_count) >= settings.RATE_LIMIT_REQUESTS # Filter requests from last minute
recent_requests = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]
except Exception as e: # Update the list
logger.error(f"Rate limit check failed: {e}") self.requests[client_id] = recent_requests
return False
async def _update_rate_limit(client_id: str): # Check if limit exceeded
"""Update rate limit counter""" return len(recent_requests) >= self.calls_per_minute
try:
key = f"rate_limit:{client_id}"
# Increment counter def _record_request(self, client_id: str):
current_count = await redis_client.incr(key) """Record a request for rate limiting"""
now = time.time()
# Set TTL on first request if client_id not in self.requests:
if current_count == 1: self.requests[client_id] = []
await redis_client.expire(key, settings.RATE_LIMIT_WINDOW)
except Exception as e: self.requests[client_id].append(now)
logger.error(f"Rate limit update failed: {e}")
# Keep only last minute of requests
minute_ago = now - 60
self.requests[client_id] = [
req_time for req_time in self.requests[client_id]
if req_time > minute_ago
]