REFACTOR external service and improve websocket training
This commit is contained in:
@@ -255,28 +255,59 @@ async def events_stream(request: Request, tenant_id: str):
|
||||
@app.websocket("/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live")
|
||||
async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_id: str):
|
||||
"""
|
||||
WebSocket proxy that forwards connections directly to training service.
|
||||
Acts as a pure proxy - does NOT handle websocket logic, just forwards to training service.
|
||||
All auth, message handling, and business logic is in the training service.
|
||||
Simple WebSocket proxy with token verification only.
|
||||
Validates the token and forwards the connection to the training service.
|
||||
"""
|
||||
# Get token from query params (required for training service authentication)
|
||||
# Get token from query params
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
logger.warning(f"WebSocket proxy rejected - missing token for job {job_id}")
|
||||
logger.warning("WebSocket proxy rejected - missing token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Authentication token required")
|
||||
return
|
||||
|
||||
# Accept the connection immediately
|
||||
# Verify token
|
||||
from shared.auth.jwt_handler import JWTHandler
|
||||
|
||||
jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)
|
||||
|
||||
try:
|
||||
payload = jwt_handler.verify_token(token)
|
||||
if not payload or not payload.get('user_id'):
|
||||
logger.warning("WebSocket proxy rejected - invalid token",
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id)
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Invalid token")
|
||||
return
|
||||
|
||||
logger.info("WebSocket proxy - token verified",
|
||||
user_id=payload.get('user_id'),
|
||||
tenant_id=tenant_id,
|
||||
job_id=job_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("WebSocket proxy - token verification failed",
|
||||
job_id=job_id,
|
||||
error=str(e))
|
||||
await websocket.accept()
|
||||
await websocket.close(code=1008, reason="Token verification failed")
|
||||
return
|
||||
|
||||
# Accept the connection
|
||||
await websocket.accept()
|
||||
|
||||
logger.info(f"Gateway proxying WebSocket to training service for job {job_id}, tenant {tenant_id}")
|
||||
|
||||
# Build WebSocket URL to training service - forward to the exact same path
|
||||
# Build WebSocket URL to training service
|
||||
training_service_base = settings.TRAINING_SERVICE_URL.rstrip('/')
|
||||
training_ws_url = training_service_base.replace('http://', 'ws://').replace('https://', 'wss://')
|
||||
training_ws_url = f"{training_ws_url}/api/v1/tenants/{tenant_id}/training/jobs/{job_id}/live?token={token}"
|
||||
|
||||
logger.info("Gateway proxying WebSocket to training service",
|
||||
job_id=job_id,
|
||||
training_ws_url=training_ws_url.replace(token, '***'))
|
||||
|
||||
training_ws = None
|
||||
|
||||
try:
|
||||
@@ -285,17 +316,15 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
|
||||
|
||||
training_ws = await websockets.connect(
|
||||
training_ws_url,
|
||||
ping_interval=None, # Let training service handle heartbeat
|
||||
ping_timeout=None,
|
||||
close_timeout=10,
|
||||
open_timeout=30, # Allow time for training service to setup
|
||||
max_size=2**20,
|
||||
max_queue=32
|
||||
ping_interval=120, # Send ping every 2 minutes (tolerates long training operations)
|
||||
ping_timeout=60, # Wait up to 1 minute for pong (graceful timeout)
|
||||
close_timeout=60, # Increase close timeout for graceful shutdown
|
||||
open_timeout=30
|
||||
)
|
||||
|
||||
logger.info(f"Gateway connected to training service WebSocket for job {job_id}")
|
||||
logger.info("Gateway connected to training service WebSocket", job_id=job_id)
|
||||
|
||||
async def forward_to_training():
|
||||
async def forward_frontend_to_training():
|
||||
"""Forward messages from frontend to training service"""
|
||||
try:
|
||||
while training_ws and training_ws.open:
|
||||
@@ -304,55 +333,58 @@ async def websocket_training_progress(websocket: WebSocket, tenant_id: str, job_
|
||||
if data.get("type") == "websocket.receive":
|
||||
if "text" in data:
|
||||
await training_ws.send(data["text"])
|
||||
logger.debug(f"Gateway forwarded frontend->training: {data['text'][:100]}")
|
||||
elif "bytes" in data:
|
||||
await training_ws.send(data["bytes"])
|
||||
elif data.get("type") == "websocket.disconnect":
|
||||
logger.info(f"Frontend disconnected for job {job_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding frontend->training for job {job_id}: {e}")
|
||||
logger.debug("Frontend to training forward ended", error=str(e))
|
||||
|
||||
async def forward_to_frontend():
|
||||
async def forward_training_to_frontend():
|
||||
"""Forward messages from training service to frontend"""
|
||||
message_count = 0
|
||||
try:
|
||||
while training_ws and training_ws.open:
|
||||
message = await training_ws.recv()
|
||||
await websocket.send_text(message)
|
||||
logger.debug(f"Gateway forwarded training->frontend: {message[:100]}")
|
||||
message_count += 1
|
||||
|
||||
# Log every 10th message to track connectivity
|
||||
if message_count % 10 == 0:
|
||||
logger.debug("WebSocket proxy active",
|
||||
job_id=job_id,
|
||||
messages_forwarded=message_count)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding training->frontend for job {job_id}: {e}")
|
||||
logger.info("Training to frontend forward ended",
|
||||
job_id=job_id,
|
||||
messages_forwarded=message_count,
|
||||
error=str(e))
|
||||
|
||||
# Run both forwarding tasks concurrently
|
||||
await asyncio.gather(
|
||||
forward_to_training(),
|
||||
forward_to_frontend(),
|
||||
forward_frontend_to_training(),
|
||||
forward_training_to_frontend(),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
logger.warning(f"Training service WebSocket closed for job {job_id}: {e}")
|
||||
except websockets.exceptions.WebSocketException as e:
|
||||
logger.error(f"WebSocket exception for job {job_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket proxy error for job {job_id}: {e}")
|
||||
logger.error("WebSocket proxy error", job_id=job_id, error=str(e))
|
||||
finally:
|
||||
# Cleanup
|
||||
if training_ws and not training_ws.closed:
|
||||
try:
|
||||
await training_ws.close()
|
||||
logger.info(f"Closed training service WebSocket for job {job_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing training service WebSocket for job {job_id}: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if not websocket.client_state.name == 'DISCONNECTED':
|
||||
await websocket.close(code=1000, reason="Proxy closed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing frontend WebSocket for job {job_id}: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.info(f"Gateway WebSocket proxy cleanup completed for job {job_id}")
|
||||
logger.info("WebSocket proxy connection closed", job_id=job_id)
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: uvicorn is imported here, inside the guard,
    # so it is only required when this module is executed directly.
    import uvicorn

    # Start the ASGI server exactly once. The call was previously repeated
    # verbatim, which would have started a second server after the first
    # one returned.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@@ -106,6 +106,12 @@ async def proxy_tenant_traffic(request: Request, tenant_id: str = Path(...), pat
|
||||
target_path = f"/api/v1/tenants/{tenant_id}/traffic/{path}".rstrip("/")
|
||||
return await _proxy_to_external_service(request, target_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/external/{path:path}", methods=["GET", "POST", "OPTIONS"])
async def proxy_tenant_external(request: Request, tenant_id: str = Path(...), path: str = ""):
    """Forward tenant-scoped external requests (v2.0 city-based optimized endpoints).

    Rebuilds the tenant-scoped ``/external/`` path and hands the request to
    ``_proxy_to_external_service``, returning whatever that helper returns.
    """
    upstream_path = f"/api/v1/tenants/{tenant_id}/external/{path}"
    # Trailing slashes are stripped, so an empty ``path`` yields ".../external".
    upstream_path = upstream_path.rstrip("/")
    return await _proxy_to_external_service(request, upstream_path)
|
||||
|
||||
@router.api_route("/{tenant_id}/analytics/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_tenant_analytics(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant analytics requests to sales service"""
|
||||
@@ -144,6 +150,12 @@ async def proxy_tenant_statistics(request: Request, tenant_id: str = Path(...)):
|
||||
# TENANT-SCOPED FORECASTING SERVICE ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasting/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_tenant_forecasting(request: Request, tenant_id: str = Path(...), path: str = ""):
    """Forward tenant-scoped forecasting requests to the forecasting service.

    Rebuilds the tenant-scoped ``/forecasting/`` path and hands the request,
    together with ``tenant_id``, to ``_proxy_to_forecasting_service``,
    returning whatever that helper returns.
    """
    upstream_path = f"/api/v1/tenants/{tenant_id}/forecasting/{path}"
    # Trailing slashes are stripped, so an empty ``path`` yields ".../forecasting".
    upstream_path = upstream_path.rstrip("/")
    return await _proxy_to_forecasting_service(request, upstream_path, tenant_id=tenant_id)
|
||||
|
||||
@router.api_route("/{tenant_id}/forecasts/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
async def proxy_tenant_forecasts(request: Request, tenant_id: str = Path(...), path: str = ""):
|
||||
"""Proxy tenant forecast requests to forecasting service"""
|
||||
|
||||
Reference in New Issue
Block a user