Add base kubernetes support final fix 4
This commit is contained in:
@@ -249,60 +249,193 @@ async def events_stream(request: Request, token: str):
|
||||
|
||||
@app.websocket("/api/v1/ws/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
    """Proxy a frontend WebSocket to the training service's live-progress socket.

    The frontend socket is accepted first; the ``token`` query parameter is
    then validated with ``jwt_handler.verify_token`` and an explicit ``exp``
    check. Any authentication failure closes the socket with code 1008.

    After validation, a client connection is opened to
    ``settings.TRAINING_SERVICE_URL`` (scheme rewritten to ws/wss) and text
    messages are relayed in both directions until either side stops. The
    proxy then closes both sockets: code 1011 on an unexpected proxy error,
    code 1000 otherwise.

    Args:
        websocket: The incoming frontend WebSocket connection.
        tenant_id: Tenant path parameter, forwarded in the upstream URL.
        job_id: Training job path parameter, forwarded in the upstream URL.
    """
    await websocket.accept()

    # Get token from query params
    token = websocket.query_params.get("token")
    if not token:
        logger.warning(f"WebSocket connection rejected - missing token for job {job_id}")
        await websocket.close(code=1008, reason="Authentication token required")
        return

    # Validate token using auth middleware
    from app.middleware.auth import jwt_handler
    try:
        payload = jwt_handler.verify_token(token)
        if not payload:
            logger.warning(f"WebSocket connection rejected - invalid token for job {job_id}")
            await websocket.close(code=1008, reason="Invalid authentication token")
            return

        # Check token expiration
        import time
        if payload.get('exp', 0) < time.time():
            logger.warning(f"WebSocket connection rejected - expired token for job {job_id}")
            await websocket.close(code=1008, reason="Token expired")
            return

        logger.info(f"WebSocket token validated for user {payload.get('email', 'unknown')}")

    except Exception as e:
        logger.warning(f"WebSocket token validation failed for job {job_id}: {e}")
        await websocket.close(code=1008, reason="Token validation failed")
        return

    logger.info(f"Proxying WebSocket connection to training service for job {job_id}, tenant {tenant_id}")

    # Build WebSocket URL to training service. Path segments and the token are
    # percent-encoded so request-supplied values cannot inject extra path or
    # query components into the upstream URL.
    from urllib.parse import quote, urlencode
    training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
    training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
    safe_tenant_id = quote(tenant_id, safe='')
    safe_job_id = quote(job_id, safe='')
    token_query = urlencode({'token': token})
    training_ws_url = (
        f"{training_ws_url}/api/v1/ws/tenants/{safe_tenant_id}"
        f"/training/jobs/{safe_job_id}/live?{token_query}"
    )

    import websockets

    training_ws = None
    heartbeat_task = None
    forwarding_tasks = []

    # Shared state for the nested coroutines below.
    loop = asyncio.get_running_loop()
    connection_alive = True
    last_activity = loop.time()

    async def check_connection_health():
        """Log connection health from activity timestamps only; never touches the sockets."""
        while connection_alive:
            try:
                # Check every 30 seconds (the original comments state the
                # frontend heartbeat interval is 30s).
                await asyncio.sleep(30)
                current_time = loop.time()

                # 90s corresponds to three missed 30s heartbeats.
                if current_time - last_activity > 90:
                    logger.warning(f"No frontend activity for 90s on job {job_id} - connection may be dead")
                    # Monitoring only: the forwarding loops handle real failures.
                else:
                    logger.debug(f"Connection health OK for job {job_id} - last activity {int(current_time - last_activity)}s ago")

            except Exception as e:
                logger.error(f"Connection health monitoring error for job {job_id}: {e}")
                break

    async def forward_to_training():
        """Relay frontend messages to the training service until a failure occurs."""
        nonlocal connection_alive, last_activity

        try:
            # No `.open` check here: send()/recv() raise once the upstream
            # connection is closed, and `.open` is not available on every
            # websockets connection class.
            while connection_alive:
                try:
                    # 45s leaves headroom over the 30s frontend heartbeat.
                    message = await asyncio.wait_for(websocket.receive_text(), timeout=45.0)
                    last_activity = loop.time()

                    # Forward the message to training service
                    await training_ws.send(message)
                    logger.debug(f"Forwarded message to training service for job {job_id}: {message[:100]}...")

                except asyncio.TimeoutError:
                    # Nothing received within 45 seconds; keep waiting.
                    continue
                except Exception as e:
                    logger.error(f"Error receiving from frontend for job {job_id}: {e}")
                    connection_alive = False
                    break

        except Exception as e:
            logger.error(f"Error in forward_to_training for job {job_id}: {e}")
            connection_alive = False

    async def forward_to_frontend():
        """Relay training service messages to the frontend until a failure occurs."""
        nonlocal connection_alive, last_activity

        try:
            while connection_alive:
                try:
                    # 75s per the original comments: longer than the training
                    # service's expected 60s message cadence.
                    message = await asyncio.wait_for(training_ws.recv(), timeout=75.0)
                    last_activity = loop.time()

                    # Forward the message to frontend. recv() may return bytes
                    # for binary frames, which send_text cannot carry.
                    if isinstance(message, bytes):
                        await websocket.send_bytes(message)
                    else:
                        await websocket.send_text(message)
                    logger.debug(f"Forwarded message to frontend for job {job_id}: {message[:100]}...")

                except asyncio.TimeoutError:
                    # Nothing received within 75 seconds; keep waiting.
                    continue
                except Exception as e:
                    logger.error(f"Error receiving from training service for job {job_id}: {e}")
                    connection_alive = False
                    break

        except Exception as e:
            logger.error(f"Error in forward_to_frontend for job {job_id}: {e}")
            connection_alive = False

    try:
        # Library-level ping is disabled so that heartbeats are only the
        # ping/pong messages relayed between frontend and training service.
        training_ws = await websockets.connect(
            training_ws_url,
            ping_interval=None,  # DISABLED: Let frontend handle ping/pong via message forwarding
            ping_timeout=None,   # DISABLED: No independent ping mechanism
            close_timeout=15,    # Reasonable close timeout
            max_size=2**20,      # 1MB max message size
            max_queue=32,        # Max queued messages
        )

        logger.info(f"Connected to training service WebSocket for job {job_id} with gateway ping DISABLED (frontend handles ping/pong)")

        # Start connection health monitoring
        heartbeat_task = asyncio.create_task(check_connection_health())

        # Run both forwarding directions concurrently. As soon as one of them
        # stops, the proxy is finished, so the other is cancelled instead of
        # being left blocked in its receive timeout.
        forwarding_tasks = [
            asyncio.create_task(forward_to_training()),
            asyncio.create_task(forward_to_frontend()),
        ]
        try:
            await asyncio.wait(forwarding_tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            connection_alive = False
            for task in forwarding_tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*forwarding_tasks, return_exceptions=True)

    except websockets.exceptions.ConnectionClosedError as e:
        logger.warning(f"Training service WebSocket connection closed for job {job_id}: {e}")
    except websockets.exceptions.WebSocketException as e:
        logger.error(f"WebSocket exception for job {job_id}: {e}")
    except Exception as e:
        logger.error(f"WebSocket proxy error for job {job_id}: {e}")
        try:
            await websocket.close(code=1011, reason="Training service connection failed")
        except Exception as close_error:
            # Best effort: the frontend socket may already be gone.
            logger.debug(f"Could not send 1011 close for job {job_id}: {close_error}")
    finally:
        logger.info(f"WebSocket proxy closed for job {job_id}")

        # Cleanup
        connection_alive = False
        if heartbeat_task and not heartbeat_task.done():
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        if training_ws is not None:
            try:
                # close() is idempotent, so no `.closed` pre-check is needed.
                await training_ws.close()
            except Exception as e:
                logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")

        try:
            # client_state only changes when the client disconnects;
            # application_state changes when this side has already sent a
            # close frame. Check both to avoid a second close attempt.
            frontend_closed = (
                websocket.client_state.name == 'DISCONNECTED'
                or websocket.application_state.name == 'DISCONNECTED'
            )
            if not frontend_closed:
                await websocket.close(code=1000, reason="Proxy connection closed")
        except Exception as e:
            logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")

        logger.info(f"WebSocket proxy cleanup completed for job {job_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -67,7 +67,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
# ✅ STEP 2: Verify token and get user context
|
||||
user_context = await self._verify_token(token)
|
||||
user_context = await self._verify_token(token, request)
|
||||
if not user_context:
|
||||
logger.warning(f"Invalid token for route: {request.url.path}")
|
||||
return JSONResponse(
|
||||
@@ -117,7 +117,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
tenant_id=tenant_id,
|
||||
path=request.url.path)
|
||||
|
||||
return await call_next(request)
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add token expiry warning header if token is near expiry
|
||||
if hasattr(request.state, 'token_near_expiry') and request.state.token_near_expiry:
|
||||
response.headers["X-Token-Refresh-Suggested"] = "true"
|
||||
|
||||
return response
|
||||
|
||||
def _is_public_route(self, path: str) -> bool:
|
||||
"""Check if route requires authentication"""
|
||||
@@ -130,7 +137,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
return auth_header.split(" ")[1]
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
async def _verify_token(self, token: str, request: Request = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Verify JWT token with improved fallback strategy
|
||||
FIXED: Better error handling and token structure validation
|
||||
@@ -141,6 +148,17 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if payload and self._validate_token_payload(payload):
|
||||
logger.debug("Token validated locally")
|
||||
|
||||
# Check if token is near expiry and set flag for response header
|
||||
if request:
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
request.state.token_near_expiry = True
|
||||
|
||||
# Convert JWT payload to user context format
|
||||
return self._jwt_payload_to_user_context(payload)
|
||||
except Exception as e:
|
||||
@@ -177,18 +195,26 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
required_fields = ["user_id", "email", "exp", "type"]
|
||||
missing_fields = [field for field in required_fields if field not in payload]
|
||||
|
||||
|
||||
if missing_fields:
|
||||
logger.warning(f"Token payload missing fields: {missing_fields}")
|
||||
return False
|
||||
|
||||
|
||||
# Validate token type
|
||||
token_type = payload.get("type")
|
||||
if token_type not in ["access", "service"]:
|
||||
logger.warning(f"Invalid token type: {payload.get('type')}")
|
||||
return False
|
||||
|
||||
|
||||
# Check if token is near expiry (within 5 minutes) and log warning
|
||||
import time
|
||||
exp_time = payload.get("exp", 0)
|
||||
current_time = time.time()
|
||||
time_until_expiry = exp_time - current_time
|
||||
|
||||
if time_until_expiry < 300: # 5 minutes
|
||||
logger.warning(f"Token expires in {int(time_until_expiry)} seconds for user {payload.get('email')}")
|
||||
|
||||
return True
|
||||
|
||||
def _jwt_payload_to_user_context(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user