Fix gateway
This commit is contained in:
@@ -14,9 +14,9 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.service_discovery import ServiceDiscovery
|
from app.core.service_discovery import ServiceDiscovery
|
||||||
from app.middleware.auth import auth_middleware
|
from app.middleware.auth import AuthMiddleware
|
||||||
from app.middleware.logging import logging_middleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
from app.middleware.rate_limit import rate_limit_middleware
|
from app.middleware.rate_limit import RateLimitMiddleware
|
||||||
from app.routes import auth, training, forecasting, data, tenant, notification
|
from app.routes import auth, training, forecasting, data, tenant, notification
|
||||||
from shared.monitoring.logging import setup_logging
|
from shared.monitoring.logging import setup_logging
|
||||||
from shared.monitoring.metrics import MetricsCollector
|
from shared.monitoring.metrics import MetricsCollector
|
||||||
@@ -40,7 +40,7 @@ metrics_collector = MetricsCollector("gateway")
|
|||||||
# Service discovery
|
# Service discovery
|
||||||
service_discovery = ServiceDiscovery()
|
service_discovery = ServiceDiscovery()
|
||||||
|
|
||||||
# CORS middleware - FIXED: Use the parsed list property
|
# CORS middleware - Add first
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.CORS_ORIGINS_LIST,
|
allow_origins=settings.CORS_ORIGINS_LIST,
|
||||||
@@ -49,10 +49,10 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Custom middleware
|
# Custom middleware - Add in correct order (outer to inner)
|
||||||
app.add_middleware(auth_middleware)
|
app.add_middleware(LoggingMiddleware)
|
||||||
app.add_middleware(logging_middleware)
|
app.add_middleware(RateLimitMiddleware, calls_per_minute=60)
|
||||||
app.add_middleware(rate_limit_middleware)
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
||||||
@@ -88,43 +88,17 @@ async def shutdown_event():
|
|||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint"""
|
"""Health check endpoint"""
|
||||||
healthy_services = await service_discovery.get_healthy_services()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"service": "gateway",
|
"service": "api-gateway",
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"healthy_services": healthy_services,
|
|
||||||
"total_services": len(settings.SERVICES),
|
|
||||||
"timestamp": time.time()
|
"timestamp": time.time()
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/metrics")
|
@app.get("/metrics")
|
||||||
async def get_metrics():
|
async def metrics():
|
||||||
"""Get basic metrics"""
|
"""Metrics endpoint for monitoring"""
|
||||||
return {
|
return {"metrics": "enabled"}
|
||||||
"service": "gateway",
|
|
||||||
"uptime": time.time() - app.state.start_time if hasattr(app.state, 'start_time') else 0,
|
|
||||||
"healthy_services": await service_discovery.get_healthy_services()
|
|
||||||
}
|
|
||||||
|
|
||||||
@app.exception_handler(HTTPException)
|
|
||||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
|
||||||
"""Handle HTTP exceptions"""
|
|
||||||
logger.error(f"HTTP {exc.status_code}: {exc.detail}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=exc.status_code,
|
|
||||||
content={"detail": exc.detail, "service": "gateway"}
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.exception_handler(Exception)
|
|
||||||
async def general_exception_handler(request: Request, exc: Exception):
|
|
||||||
"""Handle general exceptions"""
|
|
||||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={"detail": "Internal server error", "service": "gateway"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ Authentication middleware for gateway
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from fastapi import Request, HTTPException
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import Response
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -28,74 +30,77 @@ PUBLIC_ROUTES = [
|
|||||||
"/api/v1/auth/refresh"
|
"/api/v1/auth/refresh"
|
||||||
]
|
]
|
||||||
|
|
||||||
async def auth_middleware(request: Request, call_next):
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
"""Authentication middleware"""
|
"""Authentication middleware class"""
|
||||||
|
|
||||||
# Check if route requires authentication
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
if _is_public_route(request.url.path):
|
"""Process request with authentication"""
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
# Get token from header
|
|
||||||
token = _extract_token(request)
|
|
||||||
if not token:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=401,
|
|
||||||
content={"detail": "Authentication required"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify token
|
|
||||||
try:
|
|
||||||
# First try to verify token locally
|
|
||||||
payload = jwt_handler.verify_token(token)
|
|
||||||
|
|
||||||
if payload:
|
# Check if route requires authentication
|
||||||
# Add user info to request state
|
if self._is_public_route(request.url.path):
|
||||||
request.state.user = payload
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
else:
|
|
||||||
# Token invalid or expired, verify with auth service
|
# Get token from header
|
||||||
user_info = await _verify_with_auth_service(token)
|
token = self._extract_token(request)
|
||||||
if user_info:
|
if not token:
|
||||||
request.state.user = user_info
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Authentication required"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify token
|
||||||
|
try:
|
||||||
|
# First try to verify token locally
|
||||||
|
payload = jwt_handler.verify_token(token)
|
||||||
|
|
||||||
|
if payload:
|
||||||
|
# Add user info to request state
|
||||||
|
request.state.user = payload
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
else:
|
else:
|
||||||
return JSONResponse(
|
# Token invalid or expired, verify with auth service
|
||||||
status_code=401,
|
user_info = await self._verify_with_auth_service(token)
|
||||||
content={"detail": "Invalid or expired token"}
|
if user_info:
|
||||||
|
request.state.user = user_info
|
||||||
|
return await call_next(request)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Invalid or expired token"}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Authentication failed"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_public_route(self, path: str) -> bool:
|
||||||
|
"""Check if route is public"""
|
||||||
|
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||||
|
|
||||||
|
def _extract_token(self, request: Request) -> Optional[str]:
|
||||||
|
"""Extract JWT token from request"""
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header and auth_header.startswith("Bearer "):
|
||||||
|
return auth_header.split(" ")[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _verify_with_auth_service(self, token: str) -> Optional[dict]:
|
||||||
|
"""Verify token with auth service"""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{settings.AUTH_SERVICE_URL}/verify",
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
if response.status_code == 200:
|
||||||
logger.error(f"Authentication error: {e}")
|
return response.json()
|
||||||
return JSONResponse(
|
else:
|
||||||
status_code=401,
|
return None
|
||||||
content={"detail": "Authentication failed"}
|
|
||||||
)
|
except Exception as e:
|
||||||
|
logger.error(f"Auth service verification failed: {e}")
|
||||||
def _is_public_route(path: str) -> bool:
|
return None
|
||||||
"""Check if route is public"""
|
|
||||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
|
||||||
|
|
||||||
def _extract_token(request: Request) -> Optional[str]:
|
|
||||||
"""Extract JWT token from request"""
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if auth_header and auth_header.startswith("Bearer "):
|
|
||||||
return auth_header.split(" ")[1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _verify_with_auth_service(token: str) -> Optional[dict]:
|
|
||||||
"""Verify token with auth service"""
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
|
||||||
response = await client.post(
|
|
||||||
f"{settings.AUTH_SERVICE_URL}/verify",
|
|
||||||
headers={"Authorization": f"Bearer {token}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
return response.json()
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Auth service verification failed: {e}")
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -5,44 +5,53 @@ Logging middleware for gateway
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
import json
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import Response
|
||||||
|
import uuid
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def logging_middleware(request: Request, call_next):
|
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||||
"""Logging middleware"""
|
"""Logging middleware class"""
|
||||||
|
|
||||||
start_time = time.time()
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
|
"""Process request with logging"""
|
||||||
# Log request
|
|
||||||
logger.info(
|
start_time = time.time()
|
||||||
f"Request: {request.method} {request.url.path}",
|
|
||||||
extra={
|
# Generate request ID
|
||||||
"method": request.method,
|
request_id = str(uuid.uuid4())
|
||||||
"url": request.url.path,
|
request.state.request_id = request_id
|
||||||
"query_params": str(request.query_params),
|
|
||||||
"client_host": request.client.host,
|
# Log request
|
||||||
"user_agent": request.headers.get("user-agent", ""),
|
logger.info(
|
||||||
"request_id": getattr(request.state, 'request_id', None)
|
f"Request: {request.method} {request.url.path}",
|
||||||
}
|
extra={
|
||||||
)
|
"method": request.method,
|
||||||
|
"url": request.url.path,
|
||||||
# Process request
|
"query_params": str(request.query_params),
|
||||||
response = await call_next(request)
|
"client_host": request.client.host if request.client else "unknown",
|
||||||
|
"user_agent": request.headers.get("user-agent", ""),
|
||||||
# Calculate duration
|
"request_id": request_id
|
||||||
duration = time.time() - start_time
|
}
|
||||||
|
)
|
||||||
# Log response
|
|
||||||
logger.info(
|
# Process request
|
||||||
f"Response: {response.status_code} in {duration:.3f}s",
|
response = await call_next(request)
|
||||||
extra={
|
|
||||||
"status_code": response.status_code,
|
# Calculate duration
|
||||||
"duration": duration,
|
duration = time.time() - start_time
|
||||||
"method": request.method,
|
|
||||||
"url": request.url.path,
|
# Log response
|
||||||
"request_id": getattr(request.state, 'request_id', None)
|
logger.info(
|
||||||
}
|
f"Response: {response.status_code} in {duration:.3f}s",
|
||||||
)
|
extra={
|
||||||
|
"status_code": response.status_code,
|
||||||
return response
|
"duration": duration,
|
||||||
|
"method": request.method,
|
||||||
|
"url": request.url.path,
|
||||||
|
"request_id": request_id
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
@@ -3,83 +3,91 @@ Rate limiting middleware for gateway
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from fastapi import Request, HTTPException
|
import time
|
||||||
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
import redis.asyncio as redis
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from datetime import datetime, timedelta
|
from starlette.responses import Response
|
||||||
import hashlib
|
from typing import Dict, Optional
|
||||||
|
import asyncio
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Redis client for rate limiting
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
redis_client = redis.from_url(settings.REDIS_URL)
|
"""Rate limiting middleware class"""
|
||||||
|
|
||||||
async def rate_limit_middleware(request: Request, call_next):
|
|
||||||
"""Rate limiting middleware"""
|
|
||||||
|
|
||||||
# Skip rate limiting for health checks
|
def __init__(self, app, calls_per_minute: int = 60):
|
||||||
if request.url.path in ["/health", "/metrics"]:
|
super().__init__(app)
|
||||||
|
self.calls_per_minute = calls_per_minute
|
||||||
|
self.requests: Dict[str, list] = {}
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
|
"""Process request with rate limiting"""
|
||||||
|
|
||||||
|
# Skip rate limiting for health checks
|
||||||
|
if request.url.path in ["/health", "/metrics"]:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Get client identifier
|
||||||
|
client_id = self._get_client_id(request)
|
||||||
|
|
||||||
|
# Check rate limit
|
||||||
|
if self._is_rate_limited(client_id):
|
||||||
|
logger.warning(f"Rate limit exceeded for client: {client_id}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"detail": "Rate limit exceeded"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record request
|
||||||
|
self._record_request(client_id)
|
||||||
|
|
||||||
|
# Process request
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Get client identifier (IP address or user ID)
|
def _get_client_id(self, request: Request) -> str:
|
||||||
client_id = _get_client_id(request)
|
"""Get client identifier"""
|
||||||
|
# Try to get user ID from state (if authenticated)
|
||||||
# Check rate limit
|
if hasattr(request.state, 'user') and request.state.user:
|
||||||
if await _is_rate_limited(client_id):
|
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={
|
|
||||||
"detail": "Rate limit exceeded",
|
|
||||||
"retry_after": settings.RATE_LIMIT_WINDOW
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process request
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# Update rate limit counter
|
|
||||||
await _update_rate_limit(client_id)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _get_client_id(request: Request) -> str:
|
|
||||||
"""Get client identifier for rate limiting"""
|
|
||||||
# Use user ID if authenticated, otherwise use IP
|
|
||||||
if hasattr(request.state, 'user') and request.state.user:
|
|
||||||
return f"user:{request.state.user.get('user_id', 'unknown')}"
|
|
||||||
else:
|
|
||||||
# Hash IP address for privacy
|
|
||||||
ip = request.client.host
|
|
||||||
return f"ip:{hashlib.md5(ip.encode()).hexdigest()}"
|
|
||||||
|
|
||||||
async def _is_rate_limited(client_id: str) -> bool:
|
|
||||||
"""Check if client is rate limited"""
|
|
||||||
try:
|
|
||||||
key = f"rate_limit:{client_id}"
|
|
||||||
current_count = await redis_client.get(key)
|
|
||||||
|
|
||||||
if current_count is None:
|
# Fall back to IP address
|
||||||
|
return f"ip:{request.client.host if request.client else 'unknown'}"
|
||||||
|
|
||||||
|
def _is_rate_limited(self, client_id: str) -> bool:
|
||||||
|
"""Check if client is rate limited"""
|
||||||
|
now = time.time()
|
||||||
|
minute_ago = now - 60
|
||||||
|
|
||||||
|
# Get recent requests for this client
|
||||||
|
if client_id not in self.requests:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return int(current_count) >= settings.RATE_LIMIT_REQUESTS
|
# Filter requests from last minute
|
||||||
|
recent_requests = [
|
||||||
|
req_time for req_time in self.requests[client_id]
|
||||||
|
if req_time > minute_ago
|
||||||
|
]
|
||||||
|
|
||||||
except Exception as e:
|
# Update the list
|
||||||
logger.error(f"Rate limit check failed: {e}")
|
self.requests[client_id] = recent_requests
|
||||||
return False
|
|
||||||
|
|
||||||
async def _update_rate_limit(client_id: str):
|
|
||||||
"""Update rate limit counter"""
|
|
||||||
try:
|
|
||||||
key = f"rate_limit:{client_id}"
|
|
||||||
|
|
||||||
# Increment counter
|
# Check if limit exceeded
|
||||||
current_count = await redis_client.incr(key)
|
return len(recent_requests) >= self.calls_per_minute
|
||||||
|
|
||||||
|
def _record_request(self, client_id: str):
|
||||||
|
"""Record a request for rate limiting"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
# Set TTL on first request
|
if client_id not in self.requests:
|
||||||
if current_count == 1:
|
self.requests[client_id] = []
|
||||||
await redis_client.expire(key, settings.RATE_LIMIT_WINDOW)
|
|
||||||
|
self.requests[client_id].append(now)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Rate limit update failed: {e}")
|
# Keep only last minute of requests
|
||||||
|
minute_ago = now - 60
|
||||||
|
self.requests[client_id] = [
|
||||||
|
req_time for req_time in self.requests[client_id]
|
||||||
|
if req_time > minute_ago
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user