2025-07-31 15:34:35 +02:00
|
|
|
# services/training/app/api/websocket.py
|
|
|
|
|
"""
|
|
|
|
|
WebSocket endpoints for real-time training progress updates
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
import asyncio
|
|
|
|
|
from typing import Dict, Any
|
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
|
|
|
|
|
from fastapi.routing import APIRouter
|
2025-08-01 20:43:02 +02:00
|
|
|
import datetime
|
2025-07-31 15:34:35 +02:00
|
|
|
|
2025-08-01 17:55:14 +02:00
|
|
|
import structlog
|
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
2025-07-31 15:34:35 +02:00
|
|
|
from app.services.messaging import training_publisher
|
|
|
|
|
from shared.auth.decorators import (
|
|
|
|
|
get_current_user_dep,
|
|
|
|
|
get_current_tenant_id_dep
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Create WebSocket router
|
|
|
|
|
# Router that this module's WebSocket endpoint registers itself on via
# the @websocket_router.websocket(...) decorator.
websocket_router = APIRouter()
|
|
|
|
|
|
|
|
|
|
class ConnectionManager:
    """Manage WebSocket connections for training progress.

    Connections are grouped by training job so that one progress event can be
    fanned out to every client currently watching that job.
    """

    def __init__(self):
        # Structure: {job_id: {connection_id: websocket}}
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}

    async def connect(self, websocket: WebSocket, job_id: str, connection_id: str):
        """Accept the WebSocket handshake and register the connection.

        Args:
            websocket: The incoming connection; it is accepted here.
            job_id: Training job the client wants updates for.
            connection_id: Key identifying this connection within the job.
        """
        await websocket.accept()

        # Register only after a successful accept so a failed handshake never
        # leaves a dead socket in the registry.
        self.active_connections.setdefault(job_id, {})[connection_id] = websocket
        logger.info(f"WebSocket connected for job {job_id}, connection {connection_id}")

    def disconnect(self, job_id: str, connection_id: str):
        """Remove a connection from the registry.

        Safe to call for unknown jobs or connections. The job entry itself is
        dropped once its last connection is gone.
        """
        job_connections = self.active_connections.get(job_id)
        if job_connections is not None:
            job_connections.pop(connection_id, None)
            if not job_connections:
                del self.active_connections[job_id]

        logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")

    async def send_to_job(self, job_id: str, message: dict):
        """Send a JSON message to every connection registered for a job.

        Connections whose send raises are treated as dead and removed from the
        registry; the failure is logged, not propagated.

        Args:
            job_id: Job whose connections should receive the message.
            message: JSON-serialisable payload passed to ``send_json``.
        """
        if job_id not in self.active_connections:
            logger.debug(f"No active connections for job {job_id}")
            return

        disconnected_connections = []

        # Iterate over a snapshot: ``send_json`` yields to the event loop, and
        # a concurrent connect()/disconnect() for the same job would otherwise
        # resize the dict mid-iteration and raise RuntimeError.
        for connection_id, websocket in list(self.active_connections[job_id].items()):
            try:
                await websocket.send_json(message)
                logger.debug(f"📤 Sent {message.get('type', 'unknown')} to connection {connection_id}")
            except Exception as e:
                logger.warning(f"Failed to send message to connection {connection_id}: {e}")
                disconnected_connections.append(connection_id)

        # Clean up connections that failed during the send loop.
        for connection_id in disconnected_connections:
            self.disconnect(job_id, connection_id)

        # The job entry may have been removed above, hence the .get().
        active_count = len(self.active_connections.get(job_id, {}))
        if active_count > 0:
            logger.info(f"📡 Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
|
2025-07-31 15:34:35 +02:00
|
|
|
|
|
|
|
|
# Global connection manager
|
|
|
|
|
# Module-level instance shared by the WebSocket endpoint (which registers and
# removes sockets) and the RabbitMQ message handler (which calls send_to_job).
connection_manager = ConnectionManager()
|
|
|
|
|
|
|
|
|
|
@websocket_router.websocket("/tenants/{tenant_id}/training/jobs/{job_id}/live")
async def training_progress_websocket(
    websocket: WebSocket,
    tenant_id: str,
    job_id: str
):
    """Stream live training progress for one job over a WebSocket.

    The connection is registered with ``connection_manager``, a background
    task is started to relay RabbitMQ training events, and an initial status
    snapshot is sent. The handler then serves a small text protocol until the
    client goes away:

    * ``"ping"``       -> replies ``"pong"``
    * ``"get_status"`` -> replies with a ``current_status`` JSON message
    * ``"close"``      -> ends the session
    * any binary frame -> replies ``"pong"``

    If nothing is received for 30 seconds a ``heartbeat`` JSON message is
    sent; a failed heartbeat ends the session.

    Args:
        websocket: The client connection.
        tenant_id: Tenant taken from the URL path.
        job_id: Training job taken from the URL path.
    """
    connection_id = f"{tenant_id}_{id(websocket)}"

    await connection_manager.connect(websocket, job_id, connection_id)
    logger.info(f"WebSocket connection established for job {job_id}")

    consumer_task = None
    # NOTE(review): nothing in this function sets this flag to True, so the
    # receive loop only ends via one of the ``break`` statements below.
    training_completed = False

    try:
        # Start RabbitMQ consumer
        consumer_task = asyncio.create_task(
            setup_rabbitmq_consumer_for_job(job_id, tenant_id)
        )

        # Presumably gives the consumer a moment to subscribe before the
        # initial status goes out — TODO confirm the delay is still needed.
        await asyncio.sleep(0.5)

        # Send initial status (best effort: a failure here is only logged).
        try:
            initial_status = await get_current_job_status(job_id, tenant_id)
            if initial_status:
                await websocket.send_json({
                    "type": "initial_status",
                    "job_id": job_id,
                    "data": initial_status
                })
        except Exception as e:
            logger.warning(f"Failed to send initial status: {e}")

        last_activity = asyncio.get_event_loop().time()

        while not training_completed:
            try:
                try:
                    # receive() rather than receive_text() so that binary
                    # frames and disconnect events reach the branches below.
                    data = await asyncio.wait_for(websocket.receive(), timeout=30.0)
                    last_activity = asyncio.get_event_loop().time()

                    if data["type"] == "websocket.receive":
                        # ASGI allows servers to include both "text" and
                        # "bytes" with one of them set to None, so test the
                        # value rather than key presence.
                        message_text = data.get("text")
                        if message_text is not None:
                            if message_text == "ping":
                                await websocket.send_text("pong")
                                logger.debug(f"Text ping received from job {job_id}")
                            elif message_text == "get_status":
                                current_status = await get_current_job_status(job_id, tenant_id)
                                if current_status:
                                    await websocket.send_json({
                                        "type": "current_status",
                                        "job_id": job_id,
                                        "data": current_status
                                    })
                            elif message_text == "close":
                                logger.info(f"Client requested connection close for job {job_id}")
                                break

                        elif data.get("bytes") is not None:
                            # Any binary frame is treated as a ping.
                            await websocket.send_text("pong")
                            logger.debug(f"Binary ping received for job {job_id}")

                    elif data["type"] == "websocket.disconnect":
                        logger.info(f"WebSocket disconnect message received for job {job_id}")
                        break

                except asyncio.TimeoutError:
                    # No message received in 30 seconds - send heartbeat
                    current_time = asyncio.get_event_loop().time()
                    if current_time - last_activity > 60:
                        logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")

                    try:
                        await websocket.send_json({
                            "type": "heartbeat",
                            "job_id": job_id,
                            "timestamp": str(datetime.datetime.now())
                        })
                    except Exception as e:
                        # A failed heartbeat means the peer is gone.
                        logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
                        break

            except WebSocketDisconnect:
                logger.info(f"WebSocket client disconnected for job {job_id}")
                break
            except Exception as e:
                logger.error(f"WebSocket error for job {job_id}: {e}")

                # receive() raises this once a disconnect has already been
                # delivered; retrying can never succeed, so stop.
                if "Cannot call" in str(e) and "disconnect message" in str(e):
                    logger.error(f"FastAPI WebSocket disconnect error - connection already closed")
                    break

                # Don't break immediately for other errors - try to recover
                await asyncio.sleep(1)

        logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")

    except Exception as e:
        logger.error(f"Critical WebSocket error for job {job_id}: {e}")

    finally:
        logger.info(f"Cleaning up WebSocket connection for job {job_id}")
        connection_manager.disconnect(job_id, connection_id)

        if consumer_task and not consumer_task.done():
            if training_completed:
                logger.info(f"Training completed, cancelling consumer for job {job_id}")
            else:
                logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")

            # Always cancel before awaiting: the consumer task only exits via
            # cancellation (it sleeps in an endless keep-alive loop), so
            # awaiting it uncancelled would block this handler forever and
            # leak the task after every client disconnect.
            consumer_task.cancel()

            try:
                await consumer_task
            except asyncio.CancelledError:
                logger.info(f"Consumer task cancelled for job {job_id}")
            except Exception as e:
                logger.error(f"Consumer task error for job {job_id}: {e}")
|
|
|
|
|
|
2025-07-31 15:34:35 +02:00
|
|
|
|
|
|
|
|
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
    """Subscribe to training events and relay this job's events to WebSocket clients.

    Connects ``training_publisher`` if needed, registers a callback on the
    ``training.events`` exchange and then idles until the task is cancelled.
    Apart from cancellation, every failure is logged and swallowed, so the
    coroutine returns ``None`` rather than raising.

    Args:
        job_id: Only events whose payload carries this job id are forwarded.
        tenant_id: Used as part of the queue name.
    """
    import traceback

    logger.info(f"🚀 Setting up RabbitMQ consumer for job {job_id}")

    async def handle_training_message(message):
        """Forward one RabbitMQ message to the job's WebSocket clients and ack it."""
        try:
            payload = json.loads(message.body.decode())
            event_type = payload.get("event_type", "unknown")
            logger.debug(f"🔍 Received message for job {job_id}: {event_type}")

            event_data = payload.get("data", {})

            # Events for other jobs are acknowledged but never forwarded.
            message_job_id = event_data.get("job_id") if event_data else None
            if message_job_id != job_id:
                logger.debug(f"⏭️ Ignoring message for different job: {message_job_id}")
                await message.ack()
                return

            # Re-shape the broker event into the WebSocket wire format.
            outgoing = {
                "type": map_event_type_to_websocket_type(event_type),
                "job_id": job_id,
                "timestamp": payload.get("timestamp"),
                "data": event_data,
            }

            logger.info(f"📤 Forwarding {event_type} message to WebSocket clients for job {job_id}")
            await connection_manager.send_to_job(job_id, outgoing)

            # Terminal events are only logged here; no completion state is
            # recorded, clients learn about it from the forwarded message.
            if event_type in ("training.completed", "training.failed"):
                logger.info(f"🎯 Training completion detected for job {job_id}: {event_type}")

            await message.ack()
            logger.debug(f"✅ Successfully processed {event_type} for job {job_id}")

        except Exception as e:
            logger.error(f"❌ Error handling training message for job {job_id}: {e}")
            logger.error(f"💥 Traceback: {traceback.format_exc()}")
            # Reject without requeue so a bad message is not redelivered.
            await message.nack(requeue=False)

    try:
        # NOTE(review): the name depends only on job and tenant, so two
        # sockets watching the same job use the same queue name — confirm
        # this is intended.
        queue_name = f"websocket_training_{job_id}_{tenant_id}"

        # Make sure the shared publisher has a live connection first.
        if not training_publisher.connected:
            logger.warning(f"⚠️ Training publisher not connected for job {job_id}, attempting to connect...")
            connected = await training_publisher.connect()
            if not connected:
                logger.error(f"❌ Failed to connect training_publisher for job {job_id}")
                return

        logger.info(f"🔗 Subscribing to training events for job {job_id}")
        # NOTE(review): in AMQP topic exchanges "*" matches exactly one word.
        # If events such as "training.step.completed" are published with the
        # event type as routing key, "training.#" would be needed — confirm
        # against the publisher.
        subscribed = await training_publisher.consume_events(
            exchange_name="training.events",
            queue_name=queue_name,
            routing_key="training.*",
            callback=handle_training_message,
        )

        if not subscribed:
            logger.error(f"❌ Failed to set up RabbitMQ consumer for job {job_id}")
            return

        logger.info(f"✅ Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")

        # Idle until the owning task is cancelled; messages are delivered
        # through the callback registered above.
        try:
            while True:
                await asyncio.sleep(10)
                logger.debug(f"🔄 Consumer heartbeat for job {job_id}")
        except asyncio.CancelledError:
            logger.info(f"🛑 Consumer cancelled for job {job_id}")
            raise
        except Exception as e:
            logger.error(f"💥 Consumer error for job {job_id}: {e}")
            raise

    except Exception as e:
        logger.error(f"💥 Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
        logger.error(f"🔥 Traceback: {traceback.format_exc()}")
|
|
|
|
|
|
2025-07-31 15:34:35 +02:00
|
|
|
|
|
|
|
|
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
    """Translate a RabbitMQ event type into the WebSocket message ``type``.

    Args:
        rabbitmq_event_type: Dotted event name, e.g. ``"training.progress"``.

    Returns:
        The short WebSocket type name, or ``"unknown"`` for any event type
        that is not listed below.
    """
    websocket_types = dict(
        (
            # Job lifecycle
            ("training.started", "started"),
            ("training.progress", "progress"),
            ("training.completed", "completed"),
            ("training.failed", "failed"),
            ("training.cancelled", "cancelled"),
            # Step / product / model level events
            ("training.step.completed", "step_completed"),
            ("training.product.started", "product_started"),
            ("training.product.completed", "product_completed"),
            ("training.product.failed", "product_failed"),
            ("training.model.trained", "model_trained"),
            # Data validation
            ("training.data.validation.started", "validation_started"),
            ("training.data.validation.completed", "validation_completed"),
        )
    )

    try:
        return websocket_types[rabbitmq_event_type]
    except KeyError:
        return "unknown"
|
|
|
|
|
|
|
|
|
|
async def get_current_job_status(job_id: str, tenant_id: str) -> Dict[str, Any]:
    """Return a status snapshot for a training job.

    This is currently a stub: apart from ``job_id`` every field in the
    result is a hard-coded placeholder, and ``tenant_id`` is not used.

    Args:
        job_id: Job to report on; echoed back in the result.
        tenant_id: Owning tenant (unused for now).

    Returns:
        A dict with ``job_id``, ``status``, ``progress``, ``current_step``
        and ``started_at``. Despite the annotation, ``None`` is returned if
        anything in the body raises (including the imports below).
    """
    try:
        # Neither name is used yet; they are imported inside the try, so a
        # failing import is logged and reported as "no status" (None).
        from app.core.database import get_db_session
        from app.models.training import ModelTrainingLog  # Assuming you have this model

        # TODO: replace the stub below with a real lookup of the job's
        # current state, adjusted to the actual database models.
        placeholder_status: Dict[str, Any] = dict(
            job_id=job_id,
            status="running",  # or "completed", "failed", etc.
            progress=0,
            current_step="Starting...",
            started_at="2025-07-30T19:00:00Z",
        )
        return placeholder_status

    except Exception as e:
        logger.error(f"Failed to get current job status: {e}")
        return None
|