REFACTOR API gateway
This commit is contained in:
@@ -17,7 +17,7 @@ from app.core.service_discovery import ServiceDiscovery
|
||||
from app.middleware.auth import AuthMiddleware
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.routes import auth, training, forecasting, data, tenant, notification, nominatim, user
|
||||
from app.routes import auth, tenant, notification, nominatim, user
|
||||
from shared.monitoring.logging import setup_logging
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
|
||||
@@ -56,10 +56,7 @@ app.add_middleware(AuthMiddleware)
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
|
||||
app.include_router(auth.router, prefix="/api/v1/user", tags=["user"])
|
||||
app.include_router(training.router, prefix="/api/v1/training", tags=["training"])
|
||||
app.include_router(forecasting.router, prefix="/api/v1/forecasting", tags=["forecasting"])
|
||||
app.include_router(data.router, prefix="/api/v1/data", tags=["data"])
|
||||
app.include_router(user.router, prefix="/api/v1/user", tags=["user"])
|
||||
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
|
||||
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
|
||||
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# gateway/app/middleware/auth.py - IMPROVED VERSION
|
||||
# gateway/app/middleware/auth.py
|
||||
"""
|
||||
Enhanced Authentication Middleware for API Gateway
|
||||
Implements proper token validation and tenant context extraction
|
||||
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
|
||||
"""
|
||||
|
||||
import structlog
|
||||
@@ -9,12 +8,11 @@ from fastapi import Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
import httpx
|
||||
from typing import Optional, Dict, Any
|
||||
import asyncio
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -32,20 +30,12 @@ PUBLIC_ROUTES = [
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/api/v1/auth/verify",
|
||||
"/api/v1/tenant/register",
|
||||
"/api/v1/nominatim/search"
|
||||
]
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Enhanced Authentication Middleware following microservices best practices
|
||||
|
||||
Responsibilities:
|
||||
1. Token validation (local first, then auth service)
|
||||
2. User context injection
|
||||
3. Tenant context extraction (per request)
|
||||
4. Rate limiting enforcement
|
||||
5. Request routing decisions
|
||||
Enhanced Authentication Middleware with Tenant Access Control
|
||||
"""
|
||||
|
||||
def __init__(self, app, redis_client=None):
|
||||
@@ -53,17 +43,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
self.redis_client = redis_client # For caching and rate limiting
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process request with enhanced authentication"""
|
||||
|
||||
"""Process request with enhanced authentication and tenant access control"""
|
||||
|
||||
# Skip authentication for OPTIONS requests (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# Skip authentication for public routes
|
||||
if self._is_public_route(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract and validate JWT token
|
||||
|
||||
# ✅ STEP 1: Extract and validate JWT token
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
logger.warning(f"Missing token for protected route: {request.url.path}")
|
||||
@@ -71,8 +61,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
status_code=401,
|
||||
content={"detail": "Authentication required"}
|
||||
)
|
||||
|
||||
# Verify token and get user context
|
||||
|
||||
# ✅ STEP 2: Verify token and get user context
|
||||
# Pass self.redis_client to _verify_token to enable caching
|
||||
user_context = await self._verify_token(token)
|
||||
if not user_context:
|
||||
logger.warning(f"Invalid token for route: {request.url.path}")
|
||||
@@ -80,30 +71,50 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Extract tenant context from request (not from JWT)
|
||||
tenant_id = self._extract_tenant_from_request(request)
|
||||
|
||||
# Verify user has access to tenant (if tenant_id provided)
|
||||
if tenant_id:
|
||||
has_access = await self._verify_tenant_access(user_context["user_id"], tenant_id)
|
||||
|
||||
# ✅ STEP 3: Extract tenant context from URL using shared utility
|
||||
tenant_id = extract_tenant_id_from_path(request.url.path)
|
||||
|
||||
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
|
||||
if tenant_id and is_tenant_scoped_path(request.url.path):
|
||||
# Use TenantAccessManager for gateway-level verification with caching
|
||||
# Ensure tenant_access_manager uses the redis_client from the middleware
|
||||
if self.redis_client and tenant_access_manager.redis_client is None:
|
||||
tenant_access_manager.redis_client = self.redis_client
|
||||
|
||||
has_access = await tenant_access_manager.verify_basic_tenant_access( # Corrected method call
|
||||
user_context["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to tenant"}
|
||||
content={"detail": f"Access denied to tenant {tenant_id}"}
|
||||
)
|
||||
|
||||
# Set tenant context in request state
|
||||
request.state.tenant_id = tenant_id
|
||||
|
||||
# Inject user context into request
|
||||
request.state.tenant_verified = True
|
||||
|
||||
logger.debug(f"Tenant access verified",
|
||||
user_id=user_context["user_id"],
|
||||
tenant_id=tenant_id,
|
||||
path=request.url.path)
|
||||
|
||||
# ✅ STEP 5: Inject user context into request
|
||||
request.state.user = user_context
|
||||
request.state.authenticated = True
|
||||
|
||||
# Add user context to forwarded requests
|
||||
self._inject_auth_headers(request, user_context, tenant_id)
|
||||
|
||||
logger.debug(f"Authenticated request: {user_context['email']} -> {request.url.path}")
|
||||
|
||||
|
||||
# ✅ STEP 6: Add context headers for downstream services
|
||||
self._inject_context_headers(request, user_context, tenant_id)
|
||||
|
||||
logger.debug(f"Authenticated request",
|
||||
user_email=user_context['email'],
|
||||
tenant_id=tenant_id,
|
||||
path=request.url.path)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
@@ -117,46 +128,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
def _extract_tenant_from_request(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract tenant ID from request (NOT from JWT token)
|
||||
|
||||
Priority order:
|
||||
1. X-Tenant-ID header
|
||||
2. tenant_id query parameter
|
||||
3. tenant_id in request path
|
||||
"""
|
||||
# Method 1: Header
|
||||
tenant_id = request.headers.get("X-Tenant-ID")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Method 2: Query parameter
|
||||
tenant_id = request.query_params.get("tenant_id")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Method 3: Path parameter (extract from URLs like /api/v1/tenants/{tenant_id}/...)
|
||||
path_parts = request.url.path.split("/")
|
||||
if "tenants" in path_parts:
|
||||
try:
|
||||
tenant_index = path_parts.index("tenants")
|
||||
if tenant_index + 1 < len(path_parts):
|
||||
return path_parts[tenant_index + 1]
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token with fallback strategy:
|
||||
1. Local validation (fast)
|
||||
2. Auth service validation (authoritative)
|
||||
3. Cache valid tokens to reduce auth service calls
|
||||
"""
|
||||
"""Verify JWT token with fallback strategy"""
|
||||
|
||||
# Step 1: Try local JWT validation first (fast)
|
||||
# Try local JWT validation first (fast)
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload and self._validate_token_payload(payload):
|
||||
@@ -165,7 +140,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
logger.debug(f"Local token validation failed: {e}")
|
||||
|
||||
# Step 2: Check cache for recently validated tokens
|
||||
# Check cache for recently validated tokens
|
||||
if self.redis_client:
|
||||
try:
|
||||
cached_user = await self._get_cached_user(token)
|
||||
@@ -175,7 +150,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache lookup failed: {e}")
|
||||
|
||||
# Step 3: Verify with auth service (authoritative)
|
||||
# Verify with auth service (authoritative)
|
||||
try:
|
||||
user_context = await self._verify_with_auth_service(token)
|
||||
if user_context:
|
||||
@@ -186,7 +161,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return user_context
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service validation failed: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
|
||||
@@ -197,6 +172,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify token with auth service"""
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
|
||||
@@ -209,35 +185,25 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
logger.warning(f"Auth service returned {response.status_code}")
|
||||
return None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Auth service timeout")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Auth service error: {e}")
|
||||
return None
|
||||
|
||||
async def _verify_tenant_access(self, user_id: str, tenant_id: str) -> bool:
|
||||
"""Verify user has access to specific tenant"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/access/{user_id}"
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant access verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user context from cache"""
|
||||
if not self.redis_client:
|
||||
return None
|
||||
|
||||
cache_key = f"auth:token:{hash(token)}"
|
||||
cached_data = await self.redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
import json
|
||||
return json.loads(cached_data)
|
||||
try:
|
||||
cached_data = await self.redis_client.get(cache_key)
|
||||
if cached_data:
|
||||
import json
|
||||
if isinstance(cached_data, bytes):
|
||||
cached_data = cached_data.decode()
|
||||
return json.loads(cached_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache get failed: {e}")
|
||||
return None
|
||||
|
||||
async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
|
||||
@@ -246,45 +212,45 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return
|
||||
|
||||
cache_key = f"auth:token:{hash(token)}"
|
||||
import json
|
||||
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
|
||||
try:
|
||||
import json
|
||||
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache set failed: {e}")
|
||||
|
||||
def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
|
||||
"""
|
||||
Inject authentication headers for downstream services
|
||||
|
||||
This allows services to work both:
|
||||
1. Behind the gateway (using request.state)
|
||||
2. Called directly (using headers) for development/testing
|
||||
"""
|
||||
# Remove any existing auth headers to prevent spoofing
|
||||
headers_to_remove = [
|
||||
"x-user-id", "x-user-email", "x-user-role",
|
||||
"x-tenant-id", "x-user-permissions", "x-authenticated"
|
||||
def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
|
||||
"""Inject authentication and tenant headers for downstream services"""
|
||||
|
||||
# Remove any existing auth headers to prevent spoofing
|
||||
headers_to_remove = [
|
||||
"x-user-id", "x-user-email", "x-user-role",
|
||||
"x-tenant-id", "x-tenant-verified", "x-authenticated"
|
||||
]
|
||||
|
||||
for header in headers_to_remove:
|
||||
request.headers.__dict__["_list"] = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if k.lower() != header.lower()
|
||||
]
|
||||
|
||||
for header in headers_to_remove:
|
||||
request.headers.__dict__["_list"] = [
|
||||
(k, v) for k, v in request.headers.raw
|
||||
if k.lower() != header.lower()
|
||||
]
|
||||
|
||||
# Inject new headers
|
||||
new_headers = [
|
||||
(b"x-authenticated", b"true"),
|
||||
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
|
||||
(b"x-user-email", str(user_context.get("email", "")).encode()),
|
||||
(b"x-user-role", str(user_context.get("role", "user")).encode()),
|
||||
]
|
||||
|
||||
if tenant_id:
|
||||
new_headers.append((b"x-tenant-id", tenant_id.encode()))
|
||||
|
||||
permissions = user_context.get("permissions", [])
|
||||
if permissions:
|
||||
new_headers.append((b"x-user-permissions", ",".join(permissions).encode()))
|
||||
|
||||
# Add headers to request
|
||||
request.headers.__dict__["_list"].extend(new_headers)
|
||||
|
||||
logger.debug(f"Injected auth headers for user {user_context.get('email')}")
|
||||
|
||||
# Inject new headers
|
||||
new_headers = [
|
||||
(b"x-authenticated", b"true"),
|
||||
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
|
||||
(b"x-user-email", str(user_context.get("email", "")).encode()),
|
||||
(b"x-user-role", str(user_context.get("role", "user")).encode()),
|
||||
]
|
||||
|
||||
# Add tenant context if verified
|
||||
if tenant_id:
|
||||
new_headers.extend([
|
||||
(b"x-tenant-id", tenant_id.encode()),
|
||||
(b"x-tenant-verified", b"true")
|
||||
])
|
||||
|
||||
# Add headers to request
|
||||
request.headers.__dict__["_list"].extend(new_headers)
|
||||
|
||||
logger.debug(f"Injected context headers",
|
||||
user_id=user_context.get("user_id"),
|
||||
tenant_id=tenant_id)
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Data service routes for API Gateway - Authentication handled by gateway middleware"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@router.api_route("/sales/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_sales(request: Request, path: str):
|
||||
"""Proxy sales data requests to data service"""
|
||||
return await _proxy_request(request, f"/api/v1/sales/{path}")
|
||||
|
||||
@router.api_route("/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_weather(request: Request, path: str):
|
||||
"""Proxy weather requests to data service"""
|
||||
return await _proxy_request(request, f"/api/v1/weather/{path}")
|
||||
|
||||
@router.api_route("/traffic/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_traffic(request: Request, path: str):
|
||||
"""Proxy traffic requests to data service"""
|
||||
return await _proxy_request(request, f"/api/v1/traffic/{path}")
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str):
|
||||
"""Proxy request to data service with user context"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.DATA_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers BUT add user context from gateway auth
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None) # Remove host header
|
||||
|
||||
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
|
||||
# Gateway middleware already verified the token and added user to request.state
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
params=request.query_params,
|
||||
headers=headers,
|
||||
content=body
|
||||
)
|
||||
|
||||
# Return streaming response for large payloads
|
||||
if int(response.headers.get("content-length", 0)) > 1024:
|
||||
return StreamingResponse(
|
||||
iter([response.content]),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type")
|
||||
)
|
||||
else:
|
||||
return response.json() if response.headers.get("content-type", "").startswith("application/json") else response.content
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error("Data service request failed", error=str(e))
|
||||
raise HTTPException(status_code=503, detail="Data service unavailable")
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in data proxy", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -1,74 +0,0 @@
|
||||
# ================================================================
|
||||
# Gateway Integration: Update gateway/app/routes/forecasting.py
|
||||
# ================================================================
|
||||
"""Forecasting service routes for API Gateway"""
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@router.api_route("/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_forecasts(request: Request, path: str):
|
||||
"""Proxy forecast requests to forecasting service"""
|
||||
return await _proxy_request(request, f"/api/v1/forecasts/{path}")
|
||||
|
||||
@router.api_route("/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_predictions(request: Request, path: str):
|
||||
"""Proxy prediction requests to forecasting service"""
|
||||
return await _proxy_request(request, f"/api/v1/predictions/{path}")
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str):
|
||||
"""Proxy request to forecasting service with user context"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.FORECASTING_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers and add user context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Add user context from gateway authentication
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=request.query_params
|
||||
)
|
||||
|
||||
# Return response
|
||||
return response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error proxying to forecasting service: {e}")
|
||||
raise
|
||||
@@ -1,156 +1,200 @@
|
||||
# gateway/app/routes/tenant.py - COMPLETELY UPDATED
|
||||
"""
|
||||
Tenant routes for gateway - FIXED VERSION
|
||||
Tenant routes for API Gateway - Handles all tenant-scoped endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi import APIRouter, Request, Response, HTTPException, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ================================================================
|
||||
# TENANT MANAGEMENT ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.post("/register")
|
||||
async def create_tenant(request: Request):
|
||||
"""Proxy tenant creation to tenant service"""
|
||||
try:
|
||||
body = await request.body()
|
||||
|
||||
# ✅ FIX: Forward all headers AND add user context from gateway auth
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None) # Remove host header
|
||||
|
||||
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
|
||||
# Gateway middleware already verified the token and added user to request.state
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-User-Role"] = request.state.user.get("role", "user")
|
||||
|
||||
# Add tenant ID if it exists
|
||||
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
|
||||
headers["X-Tenant-ID"] = str(request.state.tenant_id)
|
||||
elif request.state.user.get("tenant_id"):
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
|
||||
roles = request.state.user.get("roles", [])
|
||||
if roles:
|
||||
headers["X-User-Roles"] = ",".join(roles)
|
||||
|
||||
permissions = request.state.user.get("permissions", [])
|
||||
if permissions:
|
||||
headers["X-User-Permissions"] = ",".join(permissions)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/register",
|
||||
content=body,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Tenant service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Tenant service unavailable"
|
||||
)
|
||||
return await _proxy_to_tenant_service(request, "/api/v1/tenants/register")
|
||||
|
||||
@router.get("/")
|
||||
async def get_tenants(request: Request):
|
||||
"""Get tenants"""
|
||||
@router.get("/{tenant_id}")
|
||||
async def get_tenant(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get specific tenant details"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
|
||||
|
||||
@router.put("/{tenant_id}")
|
||||
async def update_tenant(request: Request, tenant_id: str = Path(...)):
|
||||
"""Update tenant details"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
|
||||
|
||||
@router.get("/{tenant_id}/members")
|
||||
async def get_tenant_members(request: Request, tenant_id: str = Path(...)):
|
||||
"""Get tenant members"""
|
||||
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED DATA SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/sales/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_sales(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant sales requests to data service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/sales/{path}".rstrip("/")
|
||||
return await _proxy_to_data_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_weather(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant weather requests to data service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/weather/{path}".rstrip("/")
|
||||
return await _proxy_to_data_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/analytics/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant analytics requests to data service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/analytics/{path}".rstrip("/")
|
||||
return await _proxy_to_data_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED TRAINING SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/training/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_training(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant training requests to training service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/training/{path}".rstrip("/")
|
||||
return await _proxy_to_training_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/models/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_models(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant model requests to training service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/models/{path}".rstrip("/")
|
||||
return await _proxy_to_training_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED FORECASTING SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_forecasts(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant forecast requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/forecasts/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_predictions(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant prediction requests to forecasting service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/predictions/{path}".rstrip("/")
|
||||
return await _proxy_to_forecasting_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# TENANT-SCOPED NOTIFICATION SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/notifications/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_notifications(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant notification requests to notification service"""
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/notifications/{path}".rstrip("/")
|
||||
return await _proxy_to_notification_service(request, target_path)
|
||||
|
||||
# ================================================================
|
||||
# PROXY HELPER FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def _proxy_to_tenant_service(request: Request, target_path: str):
|
||||
"""Proxy request to tenant service"""
|
||||
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_data_service(request: Request, target_path: str):
|
||||
"""Proxy request to data service"""
|
||||
return await _proxy_request(request, target_path, settings.DATA_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_training_service(request: Request, target_path: str):
|
||||
"""Proxy request to training service"""
|
||||
return await _proxy_request(request, target_path, settings.TRAINING_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_forecasting_service(request: Request, target_path: str):
|
||||
"""Proxy request to forecasting service"""
|
||||
return await _proxy_request(request, target_path, settings.FORECASTING_SERVICE_URL)
|
||||
|
||||
async def _proxy_to_notification_service(request: Request, target_path: str):
|
||||
"""Proxy request to notification service"""
|
||||
return await _proxy_request(request, target_path, settings.NOTIFICATION_SERVICE_URL)
|
||||
|
||||
async def _proxy_request(request: Request, target_path: str, service_url: str):
|
||||
"""Generic proxy function with enhanced error handling"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# ✅ FIX: Same pattern for GET requests
|
||||
url = f"{service_url}{target_path}"
|
||||
|
||||
# Forward headers and add user/tenant context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Add user context from gateway auth
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-User-Role"] = request.state.user.get("role", "user")
|
||||
|
||||
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
|
||||
headers["X-Tenant-ID"] = str(request.state.tenant_id)
|
||||
elif request.state.user.get("tenant_id"):
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
|
||||
roles = request.state.user.get("roles", [])
|
||||
if roles:
|
||||
headers["X-User-Roles"] = ",".join(roles)
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Tenant service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Tenant service unavailable"
|
||||
)
|
||||
|
||||
# ✅ ADD: Generic proxy function like the data service has
|
||||
async def _proxy_tenant_request(request: Request, target_path: str, method: str = None):
|
||||
"""Proxy request to tenant service with user context"""
|
||||
try:
|
||||
url = f"{settings.TENANT_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers with user context
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
|
||||
# Add user context from gateway authentication
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-User-Role"] = request.state.user.get("role", "user")
|
||||
|
||||
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
|
||||
headers["X-Tenant-ID"] = str(request.state.tenant_id)
|
||||
elif request.state.user.get("tenant_id"):
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
|
||||
roles = request.state.user.get("roles", [])
|
||||
if roles:
|
||||
headers["X-User-Roles"] = ",".join(roles)
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
request_method = method or request.method
|
||||
if request_method in ["POST", "PUT", "PATCH"]:
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
# Add query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request_method,
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=dict(request.query_params)
|
||||
params=params
|
||||
)
|
||||
|
||||
# Handle different response types
|
||||
if response.headers.get("content-type", "").startswith("application/json"):
|
||||
try:
|
||||
content = response.json()
|
||||
except:
|
||||
content = {"message": "Invalid JSON response from service"}
|
||||
else:
|
||||
content = response.text
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
content=content
|
||||
)
|
||||
|
||||
except httpx.TimeoutError:
|
||||
logger.error(f"Timeout calling {service_url}{target_path}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail=f"Service timeout"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Tenant service unavailable: {e}")
|
||||
logger.error(f"Request error calling {service_url}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Tenant service unavailable"
|
||||
detail=f"Service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error proxying to {service_url}{target_path}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal gateway error"
|
||||
)
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
Training routes for gateway - FIXED VERSION
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Query, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
async def _proxy_training_request(request: Request, target_path: str, method: str = None):
|
||||
"""Proxy request to training service with user context"""
|
||||
|
||||
# Handle OPTIONS requests directly for CORS
|
||||
if request.method == "OPTIONS":
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
|
||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
url = f"{settings.TRAINING_SERVICE_URL}{target_path}"
|
||||
|
||||
# Forward headers AND add user context from gateway auth
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None) # Remove host header
|
||||
|
||||
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
|
||||
# Gateway middleware already verified the token and added user to request.state
|
||||
if hasattr(request.state, 'user'):
|
||||
headers["X-User-ID"] = str(request.state.user.get("user_id"))
|
||||
headers["X-User-Email"] = request.state.user.get("email", "")
|
||||
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
|
||||
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
|
||||
headers["X-User-Permissions"] = ",".join(request.state.user.get("permissions", []))
|
||||
|
||||
# Get request body if present
|
||||
body = None
|
||||
request_method = method or request.method
|
||||
if request_method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.request(
|
||||
method=request_method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=dict(request.query_params)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status_code,
|
||||
content=response.json()
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Training service unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Training service unavailable"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Training service error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/status/{training_job_id}")
|
||||
async def get_training_status(training_job_id: str, request: Request):
|
||||
"""Get training job status"""
|
||||
return await _proxy_training_request(request, f"/training/status/{training_job_id}", "GET")
|
||||
|
||||
@router.get("/models")
|
||||
async def get_trained_models(request: Request):
|
||||
"""Get trained models"""
|
||||
return await _proxy_training_request(request, "/training/models", "GET")
|
||||
|
||||
@router.get("/jobs")
|
||||
async def get_training_jobs(
|
||||
request: Request,
|
||||
limit: Optional[int] = Query(10, ge=1, le=100),
|
||||
offset: Optional[int] = Query(0, ge=0)
|
||||
):
|
||||
"""Get training jobs"""
|
||||
return await _proxy_training_request(request, f"/training/jobs?limit={limit}&offset={offset}", "GET")
|
||||
|
||||
@router.post("/jobs")
|
||||
async def start_training_job(request: Request):
|
||||
"""Start a new training job - Proxy to training service"""
|
||||
return await _proxy_training_request(request, "/training/jobs", "POST")
|
||||
Reference in New Issue
Block a user