REFACTOR API gateway

This commit is contained in:
Urtzi Alfaro
2025-07-26 18:46:52 +02:00
parent e49893e10a
commit e4885db828
24 changed files with 1049 additions and 1080 deletions

View File

@@ -17,7 +17,7 @@ from app.core.service_discovery import ServiceDiscovery
from app.middleware.auth import AuthMiddleware
from app.middleware.logging import LoggingMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.routes import auth, training, forecasting, data, tenant, notification, nominatim, user
from app.routes import auth, tenant, notification, nominatim, user
from shared.monitoring.logging import setup_logging
from shared.monitoring.metrics import MetricsCollector
@@ -56,10 +56,7 @@ app.add_middleware(AuthMiddleware)
# Include routers
app.include_router(auth.router, prefix="/api/v1/auth", tags=["authentication"])
app.include_router(auth.router, prefix="/api/v1/user", tags=["user"])
app.include_router(training.router, prefix="/api/v1/training", tags=["training"])
app.include_router(forecasting.router, prefix="/api/v1/forecasting", tags=["forecasting"])
app.include_router(data.router, prefix="/api/v1/data", tags=["data"])
app.include_router(user.router, prefix="/api/v1/user", tags=["user"])
app.include_router(tenant.router, prefix="/api/v1/tenants", tags=["tenants"])
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
app.include_router(nominatim.router, prefix="/api/v1/nominatim", tags=["location"])

View File

@@ -1,7 +1,6 @@
# gateway/app/middleware/auth.py - IMPROVED VERSION
# gateway/app/middleware/auth.py
"""
Enhanced Authentication Middleware for API Gateway
Implements proper token validation and tenant context extraction
Enhanced Authentication Middleware for API Gateway with Tenant Access Control
"""
import structlog
@@ -9,12 +8,11 @@ from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import httpx
from typing import Optional, Dict, Any
import asyncio
from app.core.config import settings
from shared.auth.jwt_handler import JWTHandler
from shared.auth.tenant_access import tenant_access_manager, extract_tenant_id_from_path, is_tenant_scoped_path
logger = structlog.get_logger()
@@ -32,20 +30,12 @@ PUBLIC_ROUTES = [
"/api/v1/auth/register",
"/api/v1/auth/refresh",
"/api/v1/auth/verify",
"/api/v1/tenant/register",
"/api/v1/nominatim/search"
]
class AuthMiddleware(BaseHTTPMiddleware):
"""
Enhanced Authentication Middleware following microservices best practices
Responsibilities:
1. Token validation (local first, then auth service)
2. User context injection
3. Tenant context extraction (per request)
4. Rate limiting enforcement
5. Request routing decisions
Enhanced Authentication Middleware with Tenant Access Control
"""
def __init__(self, app, redis_client=None):
@@ -53,17 +43,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
self.redis_client = redis_client # For caching and rate limiting
async def dispatch(self, request: Request, call_next) -> Response:
"""Process request with enhanced authentication"""
"""Process request with enhanced authentication and tenant access control"""
# Skip authentication for OPTIONS requests (CORS preflight)
if request.method == "OPTIONS":
return await call_next(request)
# Skip authentication for public routes
if self._is_public_route(request.url.path):
return await call_next(request)
# Extract and validate JWT token
# ✅ STEP 1: Extract and validate JWT token
token = self._extract_token(request)
if not token:
logger.warning(f"Missing token for protected route: {request.url.path}")
@@ -71,8 +61,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
status_code=401,
content={"detail": "Authentication required"}
)
# Verify token and get user context
# ✅ STEP 2: Verify token and get user context
# Pass self.redis_client to _verify_token to enable caching
user_context = await self._verify_token(token)
if not user_context:
logger.warning(f"Invalid token for route: {request.url.path}")
@@ -80,30 +71,50 @@ class AuthMiddleware(BaseHTTPMiddleware):
status_code=401,
content={"detail": "Invalid or expired token"}
)
# Extract tenant context from request (not from JWT)
tenant_id = self._extract_tenant_from_request(request)
# Verify user has access to tenant (if tenant_id provided)
if tenant_id:
has_access = await self._verify_tenant_access(user_context["user_id"], tenant_id)
# ✅ STEP 3: Extract tenant context from URL using shared utility
tenant_id = extract_tenant_id_from_path(request.url.path)
# ✅ STEP 4: Verify tenant access if this is a tenant-scoped route
if tenant_id and is_tenant_scoped_path(request.url.path):
# Use TenantAccessManager for gateway-level verification with caching
# Ensure tenant_access_manager uses the redis_client from the middleware
if self.redis_client and tenant_access_manager.redis_client is None:
tenant_access_manager.redis_client = self.redis_client
has_access = await tenant_access_manager.verify_basic_tenant_access( # Corrected method call
user_context["user_id"],
tenant_id
)
if not has_access:
logger.warning(f"User {user_context['email']} denied access to tenant {tenant_id}")
return JSONResponse(
status_code=403,
content={"detail": "Access denied to tenant"}
content={"detail": f"Access denied to tenant {tenant_id}"}
)
# Set tenant context in request state
request.state.tenant_id = tenant_id
# Inject user context into request
request.state.tenant_verified = True
logger.debug(f"Tenant access verified",
user_id=user_context["user_id"],
tenant_id=tenant_id,
path=request.url.path)
# ✅ STEP 5: Inject user context into request
request.state.user = user_context
request.state.authenticated = True
# Add user context to forwarded requests
self._inject_auth_headers(request, user_context, tenant_id)
logger.debug(f"Authenticated request: {user_context['email']} -> {request.url.path}")
# ✅ STEP 6: Add context headers for downstream services
self._inject_context_headers(request, user_context, tenant_id)
logger.debug(f"Authenticated request",
user_email=user_context['email'],
tenant_id=tenant_id,
path=request.url.path)
return await call_next(request)
def _is_public_route(self, path: str) -> bool:
@@ -117,46 +128,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
return auth_header.split(" ")[1]
return None
def _extract_tenant_from_request(self, request: Request) -> Optional[str]:
"""
Extract tenant ID from request (NOT from JWT token)
Priority order:
1. X-Tenant-ID header
2. tenant_id query parameter
3. tenant_id in request path
"""
# Method 1: Header
tenant_id = request.headers.get("X-Tenant-ID")
if tenant_id:
return tenant_id
# Method 2: Query parameter
tenant_id = request.query_params.get("tenant_id")
if tenant_id:
return tenant_id
# Method 3: Path parameter (extract from URLs like /api/v1/tenants/{tenant_id}/...)
path_parts = request.url.path.split("/")
if "tenants" in path_parts:
try:
tenant_index = path_parts.index("tenants")
if tenant_index + 1 < len(path_parts):
return path_parts[tenant_index + 1]
except (ValueError, IndexError):
pass
return None
async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
"""
Verify JWT token with fallback strategy:
1. Local validation (fast)
2. Auth service validation (authoritative)
3. Cache valid tokens to reduce auth service calls
"""
"""Verify JWT token with fallback strategy"""
# Step 1: Try local JWT validation first (fast)
# Try local JWT validation first (fast)
try:
payload = jwt_handler.verify_token(token)
if payload and self._validate_token_payload(payload):
@@ -165,7 +140,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
except Exception as e:
logger.debug(f"Local token validation failed: {e}")
# Step 2: Check cache for recently validated tokens
# Check cache for recently validated tokens
if self.redis_client:
try:
cached_user = await self._get_cached_user(token)
@@ -175,7 +150,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
except Exception as e:
logger.warning(f"Cache lookup failed: {e}")
# Step 3: Verify with auth service (authoritative)
# Verify with auth service (authoritative)
try:
user_context = await self._verify_with_auth_service(token)
if user_context:
@@ -186,7 +161,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
return user_context
except Exception as e:
logger.error(f"Auth service validation failed: {e}")
return None
def _validate_token_payload(self, payload: Dict[str, Any]) -> bool:
@@ -197,6 +172,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
async def _verify_with_auth_service(self, token: str) -> Optional[Dict[str, Any]]:
"""Verify token with auth service"""
try:
import httpx
async with httpx.AsyncClient(timeout=3.0) as client:
response = await client.post(
f"{settings.AUTH_SERVICE_URL}/api/v1/auth/verify",
@@ -209,35 +185,25 @@ class AuthMiddleware(BaseHTTPMiddleware):
logger.warning(f"Auth service returned {response.status_code}")
return None
except asyncio.TimeoutError:
logger.error("Auth service timeout")
return None
except Exception as e:
logger.error(f"Auth service error: {e}")
return None
async def _verify_tenant_access(self, user_id: str, tenant_id: str) -> bool:
"""Verify user has access to specific tenant"""
try:
async with httpx.AsyncClient(timeout=3.0) as client:
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/{tenant_id}/access/{user_id}"
)
return response.status_code == 200
except Exception as e:
logger.error(f"Tenant access verification failed: {e}")
return False
async def _get_cached_user(self, token: str) -> Optional[Dict[str, Any]]:
"""Get user context from cache"""
if not self.redis_client:
return None
cache_key = f"auth:token:{hash(token)}"
cached_data = await self.redis_client.get(cache_key)
if cached_data:
import json
return json.loads(cached_data)
try:
cached_data = await self.redis_client.get(cache_key)
if cached_data:
import json
if isinstance(cached_data, bytes):
cached_data = cached_data.decode()
return json.loads(cached_data)
except Exception as e:
logger.warning(f"Cache get failed: {e}")
return None
async def _cache_user(self, token: str, user_context: Dict[str, Any], ttl: int = 300):
@@ -246,45 +212,45 @@ class AuthMiddleware(BaseHTTPMiddleware):
return
cache_key = f"auth:token:{hash(token)}"
import json
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
try:
import json
await self.redis_client.setex(cache_key, ttl, json.dumps(user_context))
except Exception as e:
logger.warning(f"Cache set failed: {e}")
def _inject_auth_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
"""
Inject authentication headers for downstream services
This allows services to work both:
1. Behind the gateway (using request.state)
2. Called directly (using headers) for development/testing
"""
# Remove any existing auth headers to prevent spoofing
headers_to_remove = [
"x-user-id", "x-user-email", "x-user-role",
"x-tenant-id", "x-user-permissions", "x-authenticated"
def _inject_context_headers(self, request: Request, user_context: Dict[str, Any], tenant_id: Optional[str]):
"""Inject authentication and tenant headers for downstream services"""
# Remove any existing auth headers to prevent spoofing
headers_to_remove = [
"x-user-id", "x-user-email", "x-user-role",
"x-tenant-id", "x-tenant-verified", "x-authenticated"
]
for header in headers_to_remove:
request.headers.__dict__["_list"] = [
(k, v) for k, v in request.headers.raw
if k.lower() != header.lower()
]
for header in headers_to_remove:
request.headers.__dict__["_list"] = [
(k, v) for k, v in request.headers.raw
if k.lower() != header.lower()
]
# Inject new headers
new_headers = [
(b"x-authenticated", b"true"),
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
(b"x-user-email", str(user_context.get("email", "")).encode()),
(b"x-user-role", str(user_context.get("role", "user")).encode()),
]
if tenant_id:
new_headers.append((b"x-tenant-id", tenant_id.encode()))
permissions = user_context.get("permissions", [])
if permissions:
new_headers.append((b"x-user-permissions", ",".join(permissions).encode()))
# Add headers to request
request.headers.__dict__["_list"].extend(new_headers)
logger.debug(f"Injected auth headers for user {user_context.get('email')}")
# Inject new headers
new_headers = [
(b"x-authenticated", b"true"),
(b"x-user-id", str(user_context.get("user_id", "")).encode()),
(b"x-user-email", str(user_context.get("email", "")).encode()),
(b"x-user-role", str(user_context.get("role", "user")).encode()),
]
# Add tenant context if verified
if tenant_id:
new_headers.extend([
(b"x-tenant-id", tenant_id.encode()),
(b"x-tenant-verified", b"true")
])
# Add headers to request
request.headers.__dict__["_list"].extend(new_headers)
logger.debug(f"Injected context headers",
user_id=user_context.get("user_id"),
tenant_id=tenant_id)

View File

@@ -1,89 +0,0 @@
"""Data service routes for API Gateway - Authentication handled by gateway middleware"""
from fastapi import APIRouter, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
import httpx
import logging
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
@router.api_route("/sales/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_sales(request: Request, path: str):
"""Proxy sales data requests to data service"""
return await _proxy_request(request, f"/api/v1/sales/{path}")
@router.api_route("/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_weather(request: Request, path: str):
"""Proxy weather requests to data service"""
return await _proxy_request(request, f"/api/v1/weather/{path}")
@router.api_route("/traffic/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_traffic(request: Request, path: str):
"""Proxy traffic requests to data service"""
return await _proxy_request(request, f"/api/v1/traffic/{path}")
async def _proxy_request(request: Request, target_path: str):
"""Proxy request to data service with user context"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
url = f"{settings.DATA_SERVICE_URL}{target_path}"
# Forward headers BUT add user context from gateway auth
headers = dict(request.headers)
headers.pop("host", None) # Remove host header
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
# Gateway middleware already verified the token and added user to request.state
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
# Get request body if present
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request.method,
url=url,
params=request.query_params,
headers=headers,
content=body
)
# Return streaming response for large payloads
if int(response.headers.get("content-length", 0)) > 1024:
return StreamingResponse(
iter([response.content]),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type")
)
else:
return response.json() if response.headers.get("content-type", "").startswith("application/json") else response.content
except httpx.RequestError as e:
logger.error("Data service request failed", error=str(e))
raise HTTPException(status_code=503, detail="Data service unavailable")
except Exception as e:
logger.error("Unexpected error in data proxy", error=str(e))
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -1,74 +0,0 @@
# ================================================================
# Gateway Integration: Update gateway/app/routes/forecasting.py
# ================================================================
"""Forecasting service routes for API Gateway"""
from fastapi import APIRouter, Request
import httpx
import logging
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
@router.api_route("/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_forecasts(request: Request, path: str):
"""Proxy forecast requests to forecasting service"""
return await _proxy_request(request, f"/api/v1/forecasts/{path}")
@router.api_route("/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_predictions(request: Request, path: str):
"""Proxy prediction requests to forecasting service"""
return await _proxy_request(request, f"/api/v1/predictions/{path}")
async def _proxy_request(request: Request, target_path: str):
"""Proxy request to forecasting service with user context"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
url = f"{settings.FORECASTING_SERVICE_URL}{target_path}"
# Forward headers and add user context
headers = dict(request.headers)
headers.pop("host", None)
# Add user context from gateway authentication
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
# Get request body if present
body = None
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=body,
params=request.query_params
)
# Return response
return response.json() if response.headers.get("content-type", "").startswith("application/json") else response.text
except Exception as e:
logger.error(f"Error proxying to forecasting service: {e}")
raise

View File

@@ -1,156 +1,200 @@
# gateway/app/routes/tenant.py - COMPLETELY UPDATED
"""
Tenant routes for gateway - FIXED VERSION
Tenant routes for API Gateway - Handles all tenant-scoped endpoints
"""
from fastapi import APIRouter, Request, HTTPException
from fastapi import APIRouter, Request, Response, HTTPException, Path
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
# ================================================================
# TENANT MANAGEMENT ENDPOINTS
# ================================================================
@router.post("/register")
async def create_tenant(request: Request):
"""Proxy tenant creation to tenant service"""
try:
body = await request.body()
# ✅ FIX: Forward all headers AND add user context from gateway auth
headers = dict(request.headers)
headers.pop("host", None) # Remove host header
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
# Gateway middleware already verified the token and added user to request.state
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-User-Role"] = request.state.user.get("role", "user")
# Add tenant ID if it exists
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
headers["X-Tenant-ID"] = str(request.state.tenant_id)
elif request.state.user.get("tenant_id"):
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
roles = request.state.user.get("roles", [])
if roles:
headers["X-User-Roles"] = ",".join(roles)
permissions = request.state.user.get("permissions", [])
if permissions:
headers["X-User-Permissions"] = ",".join(permissions)
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.post(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants/register",
content=body,
headers=headers
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Tenant service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Tenant service unavailable"
)
return await _proxy_to_tenant_service(request, "/api/v1/tenants/register")
@router.get("/")
async def get_tenants(request: Request):
"""Get tenants"""
@router.get("/{tenant_id}")
async def get_tenant(request: Request, tenant_id: str = Path(...)):
"""Get specific tenant details"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
@router.put("/{tenant_id}")
async def update_tenant(request: Request, tenant_id: str = Path(...)):
"""Update tenant details"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}")
@router.get("/{tenant_id}/members")
async def get_tenant_members(request: Request, tenant_id: str = Path(...)):
"""Get tenant members"""
return await _proxy_to_tenant_service(request, f"/api/v1/tenants/{tenant_id}/members")
# ================================================================
# TENANT-SCOPED DATA SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/sales/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_sales(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant sales requests to data service"""
target_path = f"/api/v1/tenants/{tenant_id}/sales/{path}".rstrip("/")
return await _proxy_to_data_service(request, target_path)
@router.api_route("/{tenant_id}/weather/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_weather(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant weather requests to data service"""
target_path = f"/api/v1/tenants/{tenant_id}/weather/{path}".rstrip("/")
return await _proxy_to_data_service(request, target_path)
@router.api_route("/{tenant_id}/analytics/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant analytics requests to data service"""
target_path = f"/api/v1/tenants/{tenant_id}/analytics/{path}".rstrip("/")
return await _proxy_to_data_service(request, target_path)
# ================================================================
# TENANT-SCOPED TRAINING SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/training/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_training(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant training requests to training service"""
target_path = f"/api/v1/tenants/{tenant_id}/training/{path}".rstrip("/")
return await _proxy_to_training_service(request, target_path)
@router.api_route("/{tenant_id}/models/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_models(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant model requests to training service"""
target_path = f"/api/v1/tenants/{tenant_id}/models/{path}".rstrip("/")
return await _proxy_to_training_service(request, target_path)
# ================================================================
# TENANT-SCOPED FORECASTING SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_forecasts(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant forecast requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/forecasts/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path)
@router.api_route("/{tenant_id}/predictions/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_predictions(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant prediction requests to forecasting service"""
target_path = f"/api/v1/tenants/{tenant_id}/predictions/{path}".rstrip("/")
return await _proxy_to_forecasting_service(request, target_path)
# ================================================================
# TENANT-SCOPED NOTIFICATION SERVICE ENDPOINTS
# ================================================================
@router.api_route("/{tenant_id}/notifications/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_notifications(request: Request, tenant_id: str = Path(...), path: str = ""):
"""Proxy tenant notification requests to notification service"""
target_path = f"/api/v1/tenants/{tenant_id}/notifications/{path}".rstrip("/")
return await _proxy_to_notification_service(request, target_path)
# ================================================================
# PROXY HELPER FUNCTIONS
# ================================================================
async def _proxy_to_tenant_service(request: Request, target_path: str):
"""Proxy request to tenant service"""
return await _proxy_request(request, target_path, settings.TENANT_SERVICE_URL)
async def _proxy_to_data_service(request: Request, target_path: str):
"""Proxy request to data service"""
return await _proxy_request(request, target_path, settings.DATA_SERVICE_URL)
async def _proxy_to_training_service(request: Request, target_path: str):
"""Proxy request to training service"""
return await _proxy_request(request, target_path, settings.TRAINING_SERVICE_URL)
async def _proxy_to_forecasting_service(request: Request, target_path: str):
"""Proxy request to forecasting service"""
return await _proxy_request(request, target_path, settings.FORECASTING_SERVICE_URL)
async def _proxy_to_notification_service(request: Request, target_path: str):
"""Proxy request to notification service"""
return await _proxy_request(request, target_path, settings.NOTIFICATION_SERVICE_URL)
async def _proxy_request(request: Request, target_path: str, service_url: str):
"""Generic proxy function with enhanced error handling"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400"
}
)
try:
# ✅ FIX: Same pattern for GET requests
url = f"{service_url}{target_path}"
# Forward headers and add user/tenant context
headers = dict(request.headers)
headers.pop("host", None)
# Add user context from gateway auth
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-User-Role"] = request.state.user.get("role", "user")
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
headers["X-Tenant-ID"] = str(request.state.tenant_id)
elif request.state.user.get("tenant_id"):
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
roles = request.state.user.get("roles", [])
if roles:
headers["X-User-Roles"] = ",".join(roles)
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{settings.TENANT_SERVICE_URL}/api/v1/tenants",
headers=headers
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Tenant service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Tenant service unavailable"
)
# ✅ ADD: Generic proxy function like the data service has
async def _proxy_tenant_request(request: Request, target_path: str, method: str = None):
"""Proxy request to tenant service with user context"""
try:
url = f"{settings.TENANT_SERVICE_URL}{target_path}"
# Forward headers with user context
headers = dict(request.headers)
headers.pop("host", None)
# Add user context from gateway authentication
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-User-Role"] = request.state.user.get("role", "user")
if hasattr(request.state, 'tenant_id') and request.state.tenant_id:
headers["X-Tenant-ID"] = str(request.state.tenant_id)
elif request.state.user.get("tenant_id"):
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
roles = request.state.user.get("roles", [])
if roles:
headers["X-User-Roles"] = ",".join(roles)
# Get request body if present
body = None
request_method = method or request.method
if request_method in ["POST", "PUT", "PATCH"]:
if request.method in ["POST", "PUT", "PATCH"]:
body = await request.body()
# Add query parameters
params = dict(request.query_params)
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request_method,
method=request.method,
url=url,
headers=headers,
content=body,
params=dict(request.query_params)
params=params
)
# Handle different response types
if response.headers.get("content-type", "").startswith("application/json"):
try:
content = response.json()
except:
content = {"message": "Invalid JSON response from service"}
else:
content = response.text
return JSONResponse(
status_code=response.status_code,
content=response.json()
content=content
)
except httpx.TimeoutError:
logger.error(f"Timeout calling {service_url}{target_path}")
raise HTTPException(
status_code=504,
detail=f"Service timeout"
)
except httpx.RequestError as e:
logger.error(f"Tenant service unavailable: {e}")
logger.error(f"Request error calling {service_url}{target_path}: {e}")
raise HTTPException(
status_code=503,
detail="Tenant service unavailable"
detail=f"Service unavailable"
)
except Exception as e:
logger.error(f"Unexpected error proxying to {service_url}{target_path}: {e}")
raise HTTPException(
status_code=500,
detail="Internal gateway error"
)

View File

@@ -1,100 +0,0 @@
"""
Training routes for gateway - FIXED VERSION
"""
from fastapi import APIRouter, Request, HTTPException, Query, Response
from fastapi.responses import JSONResponse
import httpx
import logging
from typing import Optional
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter()
async def _proxy_training_request(request: Request, target_path: str, method: str = None):
"""Proxy request to training service with user context"""
# Handle OPTIONS requests directly for CORS
if request.method == "OPTIONS":
return Response(
status_code=200,
headers={
"Access-Control-Allow-Origin": settings.CORS_ORIGINS_LIST,
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Tenant-ID",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "86400" # Cache preflight for 24 hours
}
)
try:
url = f"{settings.TRAINING_SERVICE_URL}{target_path}"
# Forward headers AND add user context from gateway auth
headers = dict(request.headers)
headers.pop("host", None) # Remove host header
# ✅ ADD USER CONTEXT FROM GATEWAY AUTHENTICATION
# Gateway middleware already verified the token and added user to request.state
if hasattr(request.state, 'user'):
headers["X-User-ID"] = str(request.state.user.get("user_id"))
headers["X-User-Email"] = request.state.user.get("email", "")
headers["X-Tenant-ID"] = str(request.state.user.get("tenant_id"))
headers["X-User-Roles"] = ",".join(request.state.user.get("roles", []))
headers["X-User-Permissions"] = ",".join(request.state.user.get("permissions", []))
# Get request body if present
body = None
request_method = method or request.method
if request_method in ["POST", "PUT", "PATCH"]:
body = await request.body()
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.request(
method=request_method,
url=url,
headers=headers,
content=body,
params=dict(request.query_params)
)
return JSONResponse(
status_code=response.status_code,
content=response.json()
)
except httpx.RequestError as e:
logger.error(f"Training service unavailable: {e}")
raise HTTPException(
status_code=503,
detail="Training service unavailable"
)
except Exception as e:
logger.error(f"Training service error: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/status/{training_job_id}")
async def get_training_status(training_job_id: str, request: Request):
"""Get training job status"""
return await _proxy_training_request(request, f"/training/status/{training_job_id}", "GET")
@router.get("/models")
async def get_trained_models(request: Request):
"""Get trained models"""
return await _proxy_training_request(request, "/training/models", "GET")
@router.get("/jobs")
async def get_training_jobs(
request: Request,
limit: Optional[int] = Query(10, ge=1, le=100),
offset: Optional[int] = Query(0, ge=0)
):
"""Get training jobs"""
return await _proxy_training_request(request, f"/training/jobs?limit={limit}&offset={offset}", "GET")
@router.post("/jobs")
async def start_training_job(request: Request):
"""Start a new training job - Proxy to training service"""
return await _proxy_training_request(request, "/training/jobs", "POST")