Add fixes to procurement logic and fix rel-time connections
This commit is contained in:
@@ -4,6 +4,7 @@ Handles routing, authentication, rate limiting, and cross-cutting concerns
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import structlog
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -143,30 +144,29 @@ async def metrics():
|
||||
# ================================================================
|
||||
|
||||
@app.get("/api/events")
|
||||
async def events_stream(request: Request, token: str):
|
||||
"""Server-Sent Events stream for real-time notifications"""
|
||||
async def events_stream(request: Request, tenant_id: str):
|
||||
"""
|
||||
Server-Sent Events stream for real-time notifications.
|
||||
|
||||
Authentication is handled by auth middleware via query param token.
|
||||
User context is available in request.state.user (injected by middleware).
|
||||
Tenant ID is provided by the frontend as a query parameter.
|
||||
"""
|
||||
global redis_client
|
||||
|
||||
|
||||
if not redis_client:
|
||||
raise HTTPException(status_code=503, detail="SSE service unavailable")
|
||||
|
||||
# Extract tenant_id from JWT token (basic extraction - you might want proper JWT validation)
|
||||
try:
|
||||
import jwt
|
||||
import base64
|
||||
import json as json_lib
|
||||
|
||||
# Decode JWT without verification for tenant_id (in production, verify the token)
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
tenant_id = payload.get('tenant_id')
|
||||
user_id = payload.get('user_id')
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid token: missing tenant_id")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token decode error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
# Extract user context from request state (set by auth middleware)
|
||||
user_context = request.state.user
|
||||
user_id = user_context.get('user_id')
|
||||
email = user_context.get('email')
|
||||
|
||||
# Validate tenant_id parameter
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
|
||||
|
||||
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
|
||||
|
||||
logger.info(f"SSE connection established for tenant: {tenant_id}")
|
||||
|
||||
@@ -181,7 +181,7 @@ async def events_stream(request: Request, token: str):
|
||||
|
||||
# Send initial connection event
|
||||
yield f"event: connection\n"
|
||||
yield f"data: {json_lib.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
|
||||
|
||||
heartbeat_counter = 0
|
||||
|
||||
@@ -197,8 +197,8 @@ async def events_stream(request: Request, token: str):
|
||||
|
||||
if message and message['type'] == 'message':
|
||||
# Forward the alert/notification from Redis
|
||||
alert_data = json_lib.loads(message['data'])
|
||||
|
||||
alert_data = json.loads(message['data'])
|
||||
|
||||
# Determine event type based on alert data
|
||||
event_type = "notification"
|
||||
if alert_data.get('item_type') == 'alert':
|
||||
@@ -208,9 +208,9 @@ async def events_stream(request: Request, token: str):
|
||||
event_type = "notification"
|
||||
elif alert_data.get('item_type') == 'recommendation':
|
||||
event_type = "notification"
|
||||
|
||||
|
||||
yield f"event: {event_type}\n"
|
||||
yield f"data: {json_lib.dumps(alert_data)}\n\n"
|
||||
yield f"data: {json.dumps(alert_data)}\n\n"
|
||||
|
||||
logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
|
||||
|
||||
@@ -219,7 +219,7 @@ async def events_stream(request: Request, token: str):
|
||||
heartbeat_counter += 1
|
||||
if heartbeat_counter >= 10:
|
||||
yield f"event: heartbeat\n"
|
||||
yield f"data: {json_lib.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
|
||||
heartbeat_counter = 0
|
||||
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -131,10 +131,29 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return any(path.startswith(route) for route in PUBLIC_ROUTES)
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from Authorization header"""
|
||||
"""
|
||||
Extract JWT token from Authorization header or query params for SSE.
|
||||
|
||||
For SSE endpoints (/api/events), browsers' EventSource API cannot send
|
||||
custom headers, so we must accept token as query parameter.
|
||||
For all other routes, token must be in Authorization header (more secure).
|
||||
|
||||
Security note: Query param tokens are logged. Use short expiry and filter logs.
|
||||
"""
|
||||
# SSE endpoint exception: token in query param (EventSource API limitation)
|
||||
if request.url.path == "/api/events":
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
logger.debug("Token extracted from query param for SSE endpoint")
|
||||
return token
|
||||
logger.warning("SSE request missing token in query param")
|
||||
return None
|
||||
|
||||
# Standard authentication: Authorization header for all other routes
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user