Add fixes to procurement logic and fix rel-time connections

This commit is contained in:
Urtzi Alfaro
2025-10-02 13:20:30 +02:00
parent c9d8d1d071
commit 1243c2ca6d
24 changed files with 4984 additions and 348 deletions

View File

@@ -4,6 +4,7 @@ Handles routing, authentication, rate limiting, and cross-cutting concerns
"""
import asyncio
import json
import structlog
from fastapi import FastAPI, Request, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
@@ -143,30 +144,29 @@ async def metrics():
# ================================================================
@app.get("/api/events")
async def events_stream(request: Request, token: str):
"""Server-Sent Events stream for real-time notifications"""
async def events_stream(request: Request, tenant_id: str):
"""
Server-Sent Events stream for real-time notifications.
Authentication is handled by auth middleware via query param token.
User context is available in request.state.user (injected by middleware).
Tenant ID is provided by the frontend as a query parameter.
"""
global redis_client
if not redis_client:
raise HTTPException(status_code=503, detail="SSE service unavailable")
# Extract tenant_id from JWT token (basic extraction - you might want proper JWT validation)
try:
import jwt
import base64
import json as json_lib
# Decode JWT without verification for tenant_id (in production, verify the token)
payload = jwt.decode(token, options={"verify_signature": False})
tenant_id = payload.get('tenant_id')
user_id = payload.get('user_id')
if not tenant_id:
raise HTTPException(status_code=401, detail="Invalid token: missing tenant_id")
except Exception as e:
logger.error(f"Token decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid token")
# Extract user context from request state (set by auth middleware)
user_context = request.state.user
user_id = user_context.get('user_id')
email = user_context.get('email')
# Validate tenant_id parameter
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id query parameter is required")
logger.info(f"SSE connection request for user {email}, tenant {tenant_id}")
logger.info(f"SSE connection established for tenant: {tenant_id}")
@@ -181,7 +181,7 @@ async def events_stream(request: Request, token: str):
# Send initial connection event
yield f"event: connection\n"
yield f"data: {json_lib.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
yield f"data: {json.dumps({'type': 'connected', 'message': 'SSE connection established', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
@@ -197,8 +197,8 @@ async def events_stream(request: Request, token: str):
if message and message['type'] == 'message':
# Forward the alert/notification from Redis
alert_data = json_lib.loads(message['data'])
alert_data = json.loads(message['data'])
# Determine event type based on alert data
event_type = "notification"
if alert_data.get('item_type') == 'alert':
@@ -208,9 +208,9 @@ async def events_stream(request: Request, token: str):
event_type = "notification"
elif alert_data.get('item_type') == 'recommendation':
event_type = "notification"
yield f"event: {event_type}\n"
yield f"data: {json_lib.dumps(alert_data)}\n\n"
yield f"data: {json.dumps(alert_data)}\n\n"
logger.debug(f"SSE message sent to tenant {tenant_id}: {alert_data.get('title')}")
@@ -219,7 +219,7 @@ async def events_stream(request: Request, token: str):
heartbeat_counter += 1
if heartbeat_counter >= 10:
yield f"event: heartbeat\n"
yield f"data: {json_lib.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': time.time()})}\n\n"
heartbeat_counter = 0
except asyncio.CancelledError:

View File

@@ -131,10 +131,29 @@ class AuthMiddleware(BaseHTTPMiddleware):
return any(path.startswith(route) for route in PUBLIC_ROUTES)
def _extract_token(self, request: Request) -> Optional[str]:
"""Extract JWT token from Authorization header"""
"""
Extract JWT token from Authorization header or query params for SSE.
For SSE endpoints (/api/events), browsers' EventSource API cannot send
custom headers, so we must accept token as query parameter.
For all other routes, token must be in Authorization header (more secure).
Security note: Query param tokens are logged. Use short expiry and filter logs.
"""
# SSE endpoint exception: token in query param (EventSource API limitation)
if request.url.path == "/api/events":
token = request.query_params.get("token")
if token:
logger.debug("Token extracted from query param for SSE endpoint")
return token
logger.warning("SSE request missing token in query param")
return None
# Standard authentication: Authorization header for all other routes
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
return None
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]: