Websocket fix 1

This commit is contained in:
Urtzi Alfaro
2025-08-01 17:55:14 +02:00
parent 2f6f13bfef
commit 81e7ab7432
2 changed files with 573 additions and 481 deletions

View File

@@ -5,19 +5,19 @@ WebSocket endpoints for real-time training progress updates
import json import json
import asyncio import asyncio
import logging
from typing import Dict, Any from typing import Dict, Any
from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
import structlog
logger = structlog.get_logger(__name__)
from app.services.messaging import training_publisher from app.services.messaging import training_publisher
from shared.auth.decorators import ( from shared.auth.decorators import (
get_current_user_dep, get_current_user_dep,
get_current_tenant_id_dep get_current_tenant_id_dep
) )
logger = logging.getLogger(__name__)
# Create WebSocket router # Create WebSocket router
websocket_router = APIRouter() websocket_router = APIRouter()
@@ -48,8 +48,9 @@ class ConnectionManager:
logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}") logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}")
async def send_to_job(self, job_id: str, message: dict): async def send_to_job(self, job_id: str, message: dict):
"""Send message to all connections for a specific job""" """Send message to all connections for a specific job with better error handling"""
if job_id not in self.active_connections: if job_id not in self.active_connections:
logger.debug(f"No active connections for job {job_id}")
return return
# Send to all connections for this job # Send to all connections for this job
@@ -58,6 +59,7 @@ class ConnectionManager:
for connection_id, websocket in self.active_connections[job_id].items(): for connection_id, websocket in self.active_connections[job_id].items():
try: try:
await websocket.send_json(message) await websocket.send_json(message)
logger.debug(f"📤 Sent {message.get('type', 'unknown')} to connection {connection_id}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to send message to connection {connection_id}: {e}") logger.warning(f"Failed to send message to connection {connection_id}: {e}")
disconnected_connections.append(connection_id) disconnected_connections.append(connection_id)
@@ -65,6 +67,11 @@ class ConnectionManager:
# Clean up disconnected connections # Clean up disconnected connections
for connection_id in disconnected_connections: for connection_id in disconnected_connections:
self.disconnect(job_id, connection_id) self.disconnect(job_id, connection_id)
# Log successful sends
active_count = len(self.active_connections.get(job_id, {}))
if active_count > 0:
logger.info(f"📡 Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}")
# Global connection manager # Global connection manager
connection_manager = ConnectionManager() connection_manager = ConnectionManager()
@@ -76,40 +83,31 @@ async def training_progress_websocket(
job_id: str job_id: str
): ):
""" """
WebSocket endpoint for real-time training progress updates FIXED WebSocket endpoint for real-time training progress updates
Prevents premature disconnection during training
Message format sent to client:
{
"type": "progress" | "completed" | "failed" | "step_completed",
"job_id": "job_123",
"timestamp": "2025-07-30T19:08:53Z",
"data": {
"progress": 45,
"current_step": "Training model for bread",
"current_product": "bread",
"products_completed": 2,
"products_total": 5,
"estimated_time_remaining_minutes": 8
}
}
""" """
connection_id = f"{tenant_id}_{id(websocket)}" connection_id = f"{tenant_id}_{id(websocket)}"
# Accept connection # Accept connection
await connection_manager.connect(websocket, job_id, connection_id) await connection_manager.connect(websocket, job_id, connection_id)
logger.info(f"WebSocket connection established for job {job_id}")
# Set up RabbitMQ consumer for this job # Set up RabbitMQ consumer for this job
consumer_task = None consumer_task = None
training_completed = False # Track training completion
try: try:
# Start RabbitMQ consumer # Start RabbitMQ consumer
logger.info(f"Setting up RabbitMQ consumer for job {job_id}")
consumer_task = asyncio.create_task( consumer_task = asyncio.create_task(
setup_rabbitmq_consumer_for_job(job_id, tenant_id) setup_rabbitmq_consumer_for_job(job_id, tenant_id)
) )
# Give consumer time to set up
await asyncio.sleep(0.5)
# Send initial status if available # Send initial status if available
try: try:
# You can fetch current job status from database here
initial_status = await get_current_job_status(job_id, tenant_id) initial_status = await get_current_job_status(job_id, tenant_id)
if initial_status: if initial_status:
await websocket.send_json({ await websocket.send_json({
@@ -117,48 +115,92 @@ async def training_progress_websocket(
"job_id": job_id, "job_id": job_id,
"data": initial_status "data": initial_status
}) })
logger.info(f"Sent initial status for job {job_id}")
except Exception as e: except Exception as e:
logger.warning(f"Failed to send initial status: {e}") logger.warning(f"Failed to send initial status: {e}")
# Keep connection alive and handle client messages # Keep connection alive - IMPROVED ERROR HANDLING
while True: last_activity = asyncio.get_event_loop().time()
while not training_completed:
try: try:
# Wait for client ping or other messages # Wait for client messages with timeout
message = await websocket.receive_text() try:
message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
if message == "ping": last_activity = asyncio.get_event_loop().time()
await websocket.send_text("pong")
elif message == "get_status": if message == "ping":
# Send current status on demand await websocket.send_text("pong")
current_status = await get_current_job_status(job_id, tenant_id) logger.debug(f"Ping received from job {job_id}")
if current_status: elif message == "get_status":
current_status = await get_current_job_status(job_id, tenant_id)
if current_status:
await websocket.send_json({
"type": "current_status",
"job_id": job_id,
"data": current_status
})
elif message == "close":
logger.info(f"Client requested connection close for job {job_id}")
break
except asyncio.TimeoutError:
# No message received in 30 seconds - send heartbeat
current_time = asyncio.get_event_loop().time()
if current_time - last_activity > 60: # 60 seconds of inactivity
logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat")
try:
await websocket.send_json({ await websocket.send_json({
"type": "current_status", "type": "heartbeat",
"job_id": job_id, "job_id": job_id,
"data": current_status "timestamp": datetime.utcnow().isoformat()
}) })
except Exception as e:
logger.error(f"Failed to send heartbeat for job {job_id}: {e}")
break
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info(f"WebSocket disconnected for job {job_id}") logger.info(f"WebSocket client disconnected for job {job_id}")
break break
except Exception as e: except Exception as e:
logger.error(f"WebSocket error for job {job_id}: {e}") logger.error(f"WebSocket error for job {job_id}: {e}")
break # Don't break immediately - try to recover
await asyncio.sleep(1)
logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}")
except Exception as e:
logger.error(f"Critical WebSocket error for job {job_id}: {e}")
finally: finally:
# Clean up # IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting
logger.info(f"Cleaning up WebSocket connection for job {job_id}")
connection_manager.disconnect(job_id, connection_id) connection_manager.disconnect(job_id, connection_id)
# Only cancel consumer if we're truly done (not just a temporary error)
if consumer_task and not consumer_task.done(): if consumer_task and not consumer_task.done():
consumer_task.cancel() if training_completed:
logger.info(f"Training completed, cancelling consumer for job {job_id}")
consumer_task.cancel()
else:
logger.warning(f"WebSocket disconnected but training not completed for job {job_id}")
# Let the consumer continue running for other potential connections
# Don't cancel it unless we're sure the job is done
try: try:
await consumer_task await consumer_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass logger.info(f"Consumer task cancelled for job {job_id}")
except Exception as e:
logger.error(f"Consumer task error for job {job_id}: {e}")
async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str): async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"""Set up RabbitMQ consumer to listen for training events for a specific job""" """Set up RabbitMQ consumer to listen for training events for a specific job"""
logger.info(f"🚀 Setting up RabbitMQ consumer for job {job_id}")
try: try:
# Create a unique queue for this WebSocket connection # Create a unique queue for this WebSocket connection
queue_name = f"websocket_training_{job_id}_{tenant_id}" queue_name = f"websocket_training_{job_id}_{tenant_id}"
@@ -170,12 +212,16 @@ async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
body = message.body.decode() body = message.body.decode()
data = json.loads(body) data = json.loads(body)
logger.debug(f"🔍 Received message for job {job_id}: {data.get('event_type', 'unknown')}")
# Extract event data # Extract event data
event_type = data.get("event_type", "unknown") event_type = data.get("event_type", "unknown")
event_data = data.get("data", {}) event_data = data.get("data", {})
# Only process messages for this specific job # Only process messages for this specific job
if event_data.get("job_id") != job_id: message_job_id = event_data.get("job_id") if event_data else None
if message_job_id != job_id:
logger.debug(f"⏭️ Ignoring message for different job: {message_job_id}")
await message.ack() await message.ack()
return return
@@ -187,36 +233,76 @@ async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str):
"data": event_data "data": event_data
} }
logger.info(f"📤 Forwarding {event_type} message to WebSocket clients for job {job_id}")
# Send to all WebSocket connections for this job # Send to all WebSocket connections for this job
await connection_manager.send_to_job(job_id, websocket_message) await connection_manager.send_to_job(job_id, websocket_message)
# Check if this is a completion message
if event_type in ["training.completed", "training.failed"]:
logger.info(f"🎯 Training completion detected for job {job_id}: {event_type}")
# Mark training as completed (you might want to store this in a global state)
# For now, we'll let the WebSocket handle this through the message
# Acknowledge the message # Acknowledge the message
await message.ack() await message.ack()
logger.debug(f"Forwarded training event to WebSocket: {event_type}") logger.debug(f"✅ Successfully processed {event_type} for job {job_id}")
except Exception as e: except Exception as e:
logger.error(f"Error handling training message for WebSocket: {e}") logger.error(f"Error handling training message for job {job_id}: {e}")
import traceback
logger.error(f"💥 Traceback: {traceback.format_exc()}")
await message.nack(requeue=False) await message.nack(requeue=False)
# Check if training_publisher is connected
if not training_publisher.connected:
logger.warning(f"⚠️ Training publisher not connected for job {job_id}, attempting to connect...")
success = await training_publisher.connect()
if not success:
logger.error(f"❌ Failed to connect training_publisher for job {job_id}")
return
# Subscribe to training events # Subscribe to training events
await training_publisher.consume_events( logger.info(f"🔗 Subscribing to training events for job {job_id}")
success = await training_publisher.consume_events(
exchange_name="training.events", exchange_name="training.events",
queue_name=queue_name, queue_name=queue_name,
routing_key="training.*", # Listen to all training events routing_key="training.*", # Listen to all training events
callback=handle_training_message callback=handle_training_message
) )
if success:
logger.info(f"✅ Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})")
# Keep the consumer running indefinitely until cancelled
try:
while True:
await asyncio.sleep(10) # Keep consumer alive
logger.debug(f"🔄 Consumer heartbeat for job {job_id}")
except asyncio.CancelledError:
logger.info(f"🛑 Consumer cancelled for job {job_id}")
raise
except Exception as e:
logger.error(f"💥 Consumer error for job {job_id}: {e}")
raise
else:
logger.error(f"❌ Failed to set up RabbitMQ consumer for job {job_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}") logger.error(f"💥 Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}")
import traceback
logger.error(f"🔥 Traceback: {traceback.format_exc()}")
def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str: def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str:
"""Map RabbitMQ event types to WebSocket message types""" """Map RabbitMQ event types to WebSocket message types"""
mapping = { mapping = {
"training.started": "started", "training.started": "started",
"training.progress": "progress", "training.progress": "progress",
"training.completed": "completed", "training.completed": "completed", # This is the key completion event
"training.failed": "failed", "training.failed": "failed", # This is also a completion event
"training.cancelled": "cancelled", "training.cancelled": "cancelled",
"training.step.completed": "step_completed", "training.step.completed": "step_completed",
"training.product.started": "product_started", "training.product.started": "product_started",

View File

@@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!"
TEST_NAME="Test Bakery Owner" TEST_NAME="Test Bakery Owner"
REAL_CSV_FILE="bakery_sales_2023_2024.csv" REAL_CSV_FILE="bakery_sales_2023_2024.csv"
WS_BASE="ws://localhost:8002/api/v1/ws" WS_BASE="ws://localhost:8002/api/v1/ws"
WS_TEST_DURATION=200 # seconds to listen for WebSocket messages WS_TEST_DURATION=2000 # seconds to listen for WebSocket messages
WS_PID="" WS_PID=""
@@ -158,20 +158,429 @@ check_timezone_error() {
return 1 # No timezone error return 1 # No timezone error
} }
# Monitor a training job over WebSocket using only Node.js built-in modules.
#
# Generates a small Node.js client (no npm packages needed), runs it in the
# foreground and translates its exit status into a log line.
#
# Arguments:
#   $1 - tenant id
#   $2 - training job id
#   $3 - maximum time to wait, in seconds (safety timeout)
# Globals read:
#   WS_BASE, ACCESS_TOKEN
# Returns:
#   0 - the job reported "completed"
#   2 - the job reported "failed"
#   1 - anything else (handshake/connection error, timeout, early close)
test_websocket_with_nodejs_builtin() {
    local tenant_id="$1"
    local job_id="$2"
    local max_duration="$3"  # Maximum time to wait (fallback)

    echo "Using Node.js with built-in modules for WebSocket testing..."
    echo "Will monitor until job completion or ${max_duration}s timeout"

    # Write the script into a private, freshly created directory instead of a
    # predictable path in /tmp (avoids symlink/pre-creation attacks and
    # collisions between concurrent runs for the same job).
    local ws_tmp_dir
    ws_tmp_dir=$(mktemp -d "${TMPDIR:-/tmp}/websocket_test.XXXXXX") || {
        log_warning "Could not create a temporary directory for the WebSocket test script"
        return 1
    }
    local ws_test_script="$ws_tmp_dir/websocket_test.js"

    # The heredoc delimiter is quoted, so nothing below is expanded by the shell.
    cat > "$ws_test_script" << 'EOF'
// ENHANCED WebSocket test - waits for job completion
const https = require('https');
const http = require('http');
const crypto = require('crypto');

const tenantId = process.argv[2];
const jobId = process.argv[3];
const maxDuration = parseInt(process.argv[4], 10) * 1000; // Convert to milliseconds
const accessToken = process.argv[5];
const wsUrl = process.argv[6];

// Exit codes understood by the calling shell function.
const EXIT_COMPLETED = 0;
const EXIT_INCOMPLETE = 1;
const EXIT_FAILED = 2;

console.log(`🚀 Starting enhanced WebSocket monitoring`);
console.log(`Connecting to: ${wsUrl}`);
console.log(`Will wait for job completion (max ${maxDuration/1000}s)`);

// Parse WebSocket URL
const url = new URL(wsUrl);
const isSecure = url.protocol === 'wss:';
const port = url.port || (isSecure ? 443 : 80);

// Create WebSocket key
const key = crypto.randomBytes(16).toString('base64');

// WebSocket handshake headers
const headers = {
    'Upgrade': 'websocket',
    'Connection': 'Upgrade',
    'Sec-WebSocket-Key': key,
    'Sec-WebSocket-Version': '13',
    'Authorization': `Bearer ${accessToken}`
};

const options = {
    hostname: url.hostname,
    port: port,
    // Keep the query string: it is part of the request target.
    path: url.pathname + url.search,
    method: 'GET',
    headers: headers
};

console.log(`Attempting WebSocket handshake to ${url.hostname}:${port}${url.pathname}`);

const client = isSecure ? https : http;
let messageCount = 0;
let jobOutcome = null; // null while running, then 'completed' or 'failed'
let lastProgressUpdate = Date.now();
let highestProgress = 0;

// Enhanced job tracking
const jobStats = {
    startTime: Date.now(),
    progressUpdates: 0,
    stepsCompleted: [],
    productsProcessed: [],
    errors: []
};

// Build a client-to-server control frame. RFC 6455 requires every frame sent
// by a client to be masked; servers must fail the connection otherwise.
// Control-frame payloads are limited to 125 bytes.
function buildClientFrame(opcode, payload) {
    const masked = Buffer.from((payload || Buffer.alloc(0)).subarray(0, 125));
    const mask = crypto.randomBytes(4);
    for (let i = 0; i < masked.length; i++) {
        masked[i] ^= mask[i % 4];
    }
    return Buffer.concat([
        Buffer.from([0x80 | opcode, 0x80 | masked.length]),
        mask,
        masked
    ]);
}

const req = client.request(options);

req.on('upgrade', (res, socket, head) => {
    console.log('✅ WebSocket handshake successful');
    console.log('📡 Monitoring training progress...\n');

    let buffer = Buffer.alloc(0);
    let fragments = null; // payload pieces of a fragmented text message
    let finished = false;

    // Best-effort write of a masked control frame.
    function sendControlFrame(opcode, payload) {
        if (finished || !socket.writable) return;
        try {
            socket.write(buildClientFrame(opcode, payload));
        } catch (e) {
            // Ignore write errors; the 'error'/'close' handlers deal with them
        }
    }

    // Print the connection summary and return the process exit code.
    function printSummary() {
        const duration = Math.round((Date.now() - jobStats.startTime) / 1000);
        console.log(`\n📊 WebSocket connection ended`);
        console.log(`📨 Total messages received: ${messageCount}`);
        console.log(`⏱️ Connection duration: ${duration}s`);
        console.log(`📈 Highest progress reached: ${highestProgress}%`);

        if (jobOutcome === 'completed') {
            console.log('✅ Job completed successfully - connection closed normally');
            return EXIT_COMPLETED;
        }
        if (jobOutcome === 'failed') {
            console.log('❌ Job failed - connection closed after failure report');
            return EXIT_FAILED;
        }
        console.log('⚠️ Connection ended before job completion');
        console.log(`📊 Progress reached: ${highestProgress}%`);
        console.log(`📋 Steps completed: ${jobStats.stepsCompleted.length}`);
        return EXIT_INCOMPLETE;
    }

    // Single shutdown path: stop timers, report, close the socket, exit.
    function finish() {
        if (finished) return;
        const canWrite = socket.writable;
        finished = true;
        clearInterval(pingInterval);
        clearInterval(heartbeatInterval);
        clearTimeout(safetyTimeout);

        const exitCode = printSummary();
        if (canWrite) {
            try {
                // Close frame with status 1000 (normal closure)
                socket.end(buildClientFrame(0x8, Buffer.from([0x03, 0xE8])));
            } catch (e) {
                // Socket already gone - nothing left to close
            }
        }
        // Give the close frame a moment to flush before exiting.
        setTimeout(() => process.exit(exitCode), 200);
    }

    // Enhanced message processing function
    function processTrainingMessage(message, timestamp) {
        const messageType = message.type || 'unknown';
        const data = message.data || {};

        console.log(`[${timestamp}] 📨 Message ${messageCount}: ${messageType.toUpperCase()}`);

        // Track job statistics
        if (messageType === 'progress') {
            jobStats.progressUpdates++;
            const progress = data.progress || 0;
            const step = data.current_step || 'Unknown step';
            const product = data.current_product;

            // Update highest progress
            if (progress > highestProgress) {
                highestProgress = progress;
            }
            // Track steps
            if (step && !jobStats.stepsCompleted.includes(step)) {
                jobStats.stepsCompleted.push(step);
            }
            // Track products
            if (product && !jobStats.productsProcessed.includes(product)) {
                jobStats.productsProcessed.push(product);
            }

            // Display progress with enhanced formatting
            console.log(` 📊 Progress: ${progress}% (${step})`);
            if (product) {
                console.log(` 🍞 Product: ${product}`);
            }
            if (data.products_completed && data.products_total) {
                console.log(` 📦 Products: ${data.products_completed}/${data.products_total} completed`);
            }
            if (data.estimated_time_remaining_minutes) {
                console.log(` ⏱️ ETA: ${data.estimated_time_remaining_minutes} minutes`);
            }
        } else if (messageType === 'completed') {
            jobOutcome = 'completed';
            const duration = Math.round((Date.now() - jobStats.startTime) / 1000);
            console.log(`\n🎉 TRAINING COMPLETED SUCCESSFULLY!`);
            console.log(` ⏱️ Total Duration: ${duration}s`);
            if (data.results) {
                const results = data.results;
                if (results.successful_trainings !== undefined) {
                    console.log(` ✅ Models Trained: ${results.successful_trainings}`);
                }
                if (results.total_products !== undefined) {
                    console.log(` 📦 Total Products: ${results.total_products}`);
                }
                if (results.success_rate !== undefined) {
                    console.log(` 📈 Success Rate: ${results.success_rate}%`);
                }
            }
            // Close connection after completion
            setTimeout(() => {
                console.log('\n📊 Training job completed - closing WebSocket connection');
                finish();
            }, 2000); // Wait 2 seconds to ensure all final messages are received
        } else if (messageType === 'failed') {
            jobOutcome = 'failed';
            jobStats.errors.push(data);
            console.log(`\n❌ TRAINING FAILED!`);
            if (data.error) {
                console.log(` 💥 Error: ${data.error}`);
            }
            if (data.error_details) {
                console.log(` 📝 Details: ${JSON.stringify(data.error_details, null, 2)}`);
            }
            // Close connection after failure
            setTimeout(() => {
                console.log('\n📊 Training job failed - closing WebSocket connection');
                finish();
            }, 2000);
        } else if (messageType === 'step_completed') {
            console.log(` ✅ Step completed: ${data.step_name || 'Unknown'}`);
        } else if (messageType === 'product_started') {
            console.log(` 🚀 Started training: ${data.product_name || 'Unknown product'}`);
        } else if (messageType === 'product_completed') {
            console.log(` ✅ Product completed: ${data.product_name || 'Unknown product'}`);
            if (data.metrics) {
                console.log(` 📊 Metrics: ${JSON.stringify(data.metrics, null, 2)}`);
            }
        }
        console.log(''); // Add spacing between messages
    }

    // Handle one complete (reassembled) text message.
    function handleTextMessage(payload) {
        messageCount++;
        lastProgressUpdate = Date.now();
        const timestamp = new Date().toLocaleTimeString();
        try {
            const messageText = payload.toString('utf8');
            const message = JSON.parse(messageText);
            // Enhanced message processing
            processTrainingMessage(message, timestamp);
        } catch (e) {
            const rawText = payload.toString('utf8');
            console.log(`[${timestamp}] ⚠️ Raw message: ${rawText.substring(0, 200)}${rawText.length > 200 ? '...' : ''}`);
        }
    }

    // WebSocket frame parsing
    function onData(chunk) {
        buffer = Buffer.concat([buffer, chunk]);

        while (!finished && buffer.length >= 2) {
            const firstByte = buffer[0];
            const secondByte = buffer[1];
            const fin = (firstByte & 0x80) === 0x80;
            const opcode = firstByte & 0x0F;
            const masked = (secondByte & 0x80) === 0x80;
            let payloadLength = secondByte & 0x7F;
            let offset = 2;

            // Handle extended payload length
            if (payloadLength === 126) {
                if (buffer.length < offset + 2) break;
                payloadLength = buffer.readUInt16BE(offset);
                offset += 2;
            } else if (payloadLength === 127) {
                if (buffer.length < offset + 8) break;
                const high = buffer.readUInt32BE(offset);
                const low = buffer.readUInt32BE(offset + 4);
                if (high !== 0) {
                    // Dropping only the header would leave the parser in the
                    // middle of the payload, so give up on the stream instead.
                    console.log('❌ Frame payload exceeds 4 GiB - cannot continue parsing, closing');
                    finish();
                    return;
                }
                payloadLength = low;
                offset += 8;
            }

            // Servers should not mask, but account for the key if one is present
            let maskKey = null;
            if (masked) {
                if (buffer.length < offset + 4) break;
                maskKey = buffer.subarray(offset, offset + 4);
                offset += 4;
            }

            // Check if we have the complete frame
            if (buffer.length < offset + payloadLength) {
                break; // Wait for more data
            }

            // Extract payload
            let payload = buffer.subarray(offset, offset + payloadLength);
            buffer = buffer.subarray(offset + payloadLength);
            if (maskKey) {
                payload = Buffer.from(payload);
                for (let i = 0; i < payload.length; i++) {
                    payload[i] ^= maskKey[i % 4];
                }
            }

            // Handle different frame types
            if (opcode === 0x1) { // Text frame (possibly first fragment)
                if (fin) {
                    handleTextMessage(payload);
                } else {
                    fragments = [payload];
                }
            } else if (opcode === 0x0) { // Continuation frame
                if (fragments) {
                    fragments.push(payload);
                    if (fin) {
                        const whole = Buffer.concat(fragments);
                        fragments = null;
                        handleTextMessage(whole);
                    }
                }
            } else if (opcode === 0x8) { // Close frame
                console.log('🔌 WebSocket closed by server');
                finish();
                return;
            } else if (opcode === 0x9) { // Ping frame
                // Send pong response echoing the ping payload
                sendControlFrame(0xA, payload);
            }
            // Pong (0xA) and binary frames are ignored
        }
    }

    socket.on('data', onData);
    socket.on('end', finish);
    socket.on('close', finish);
    socket.on('error', (error) => {
        if (finished) return; // errors while shutting down are expected
        console.log(`❌ WebSocket error: ${error.message}`);
        finish();
    });

    // Enhanced ping mechanism - send pings more frequently
    const pingInterval = setInterval(() => {
        if (jobOutcome === null) {
            sendControlFrame(0x9);
        }
    }, 5000); // Ping every 5 seconds

    // Heartbeat check - ensure we're still receiving messages
    const heartbeatInterval = setInterval(() => {
        if (jobOutcome === null) {
            const timeSinceLastMessage = Date.now() - lastProgressUpdate;
            if (timeSinceLastMessage > 60000) { // 60 seconds without messages
                console.log('\n⚠ No messages received for 60 seconds');
                console.log(' This could indicate the training is stuck or connection issues');
                console.log(` Last progress: ${highestProgress}%`);
            } else if (timeSinceLastMessage > 30000) { // 30 seconds warning
                console.log(`\n💤 Quiet period: ${Math.round(timeSinceLastMessage/1000)}s since last update`);
                console.log(' (This is normal during intensive training phases)');
            }
        }
    }, 15000); // Check every 15 seconds

    // Safety timeout - close connection if max duration exceeded
    const safetyTimeout = setTimeout(() => {
        if (jobOutcome === null) {
            console.log(`\n⏰ Maximum duration (${maxDuration/1000}s) reached`);
            console.log(`📊 Final status:`);
            console.log(` 📨 Messages received: ${messageCount}`);
            console.log(` 📈 Progress reached: ${highestProgress}%`);
            console.log(` 📋 Steps completed: ${jobStats.stepsCompleted.length}`);
            console.log(` 🍞 Products processed: ${jobStats.productsProcessed.length}`);
            if (messageCount > 0) {
                console.log('\n✅ WebSocket communication was successful!');
                console.log(' Training may still be running - check server logs for completion');
            } else {
                console.log('\n⚠ No messages received during monitoring period');
            }
            finish();
        }
    }, maxDuration);

    // Bytes that arrived together with the handshake response are delivered
    // in `head`, not through the 'data' event - parse them too.
    if (head && head.length > 0) {
        onData(head);
    }
});

req.on('response', (res) => {
    console.log(`❌ HTTP response instead of WebSocket upgrade: ${res.statusCode}`);
    console.log('Response headers:', res.headers);
    let body = '';
    res.on('data', chunk => body += chunk);
    res.on('end', () => {
        if (body) console.log('Response body:', body);
        process.exit(EXIT_INCOMPLETE);
    });
});

req.on('error', (error) => {
    console.log(`❌ Connection error: ${error.message}`);
    process.exit(EXIT_INCOMPLETE);
});

req.end();
EOF

    # Run the ENHANCED Node.js WebSocket test
    local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live"
    echo "Starting enhanced WebSocket monitoring..."
    node "$ws_test_script" "$tenant_id" "$job_id" "$max_duration" "$ACCESS_TOKEN" "$ws_url"
    local exit_code=$?

    # Clean up
    rm -rf "$ws_tmp_dir"

    if [ $exit_code -eq 0 ]; then
        log_success "Training job completed successfully!"
        echo " 📡 WebSocket monitoring detected job completion"
        echo " 🎉 Real-time progress tracking worked perfectly"
    elif [ $exit_code -eq 2 ]; then
        log_warning "Training job reported a failure"
        echo " 📊 Check the error details above"
    else
        log_warning "WebSocket monitoring ended before job completion"
        echo " 📊 Check the progress logs above for details"
    fi
    return $exit_code
}
# Ensure the `websocat` CLI is available, trying to install it when missing.
#
# Tries `cargo install` first and, if websocat is still not on PATH, falls
# back to downloading the prebuilt static binary (x86_64 Linux only).
#
# Returns:
#   0 - websocat is on PATH (already present or freshly installed)
#   1 - websocat could not be installed; callers should use a fallback
install_websocat_if_needed() {
    if command -v websocat >/dev/null 2>&1; then
        return 0
    fi

    echo "📦 Installing websocat for better WebSocket testing..."

    # Attempt 1: build with cargo. A failure here must not prevent the
    # download attempt below, so the two are sequential rather than if/elif.
    if command -v cargo >/dev/null 2>&1; then
        cargo install websocat 2>/dev/null || true
    fi

    # Attempt 2: prebuilt binary. The release asset is an x86_64 Linux build,
    # so only fetch it on a matching platform.
    if ! command -v websocat >/dev/null 2>&1 \
        && command -v wget >/dev/null 2>&1 \
        && [ "$(uname -s)" = "Linux" ] \
        && [ "$(uname -m)" = "x86_64" ]; then
        local download_dir
        # Private directory instead of a fixed, predictable /tmp/websocat path.
        if download_dir=$(mktemp -d "${TMPDIR:-/tmp}/websocat.XXXXXX"); then
            local binary="$download_dir/websocat"
            # NOTE(review): the download is not checksum-verified; pin a
            # release and verify its hash if this runs outside a dev sandbox.
            # Install only when the download succeeded and is non-empty, so a
            # failed/partial transfer never ends up on PATH.
            if wget -q -O "$binary" "https://github.com/vi/websocat/releases/latest/download/websocat.x86_64-unknown-linux-musl" 2>/dev/null \
                && [ -s "$binary" ]; then
                chmod +x "$binary"
                sudo mv "$binary" /usr/local/bin/ 2>/dev/null \
                    || mv "$binary" ~/bin/ 2>/dev/null \
                    || true
            fi
            rm -rf "$download_dir"
        fi
    fi

    if command -v websocat >/dev/null 2>&1; then
        log_success "websocat installed successfully"
        return 0
    fi

    log_warning "websocat installation failed, using Node.js fallback"
    return 1
}
# IMPROVED: WebSocket connection function with better tool selection
test_websocket_connection() { test_websocket_connection() {
local tenant_id="$1" local tenant_id="$1"
local job_id="$2" local job_id="$2"
local duration="$3" local duration="$3"
log_step "4.2. Testing WebSocket connection for real-time training progress" log_step "4.2. Connecting to WebSocket for real-time progress monitoring"
echo "WebSocket URL: $WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live" echo "WebSocket URL: $WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live"
echo "Test duration: ${duration}s" echo "Test duration: ${duration}s"
echo "" echo ""
# Check if websocat is available # Try to install websocat if not available
if command -v websocat >/dev/null 2>&1; then if install_websocat_if_needed; then
test_websocket_with_websocat "$tenant_id" "$job_id" "$duration" test_websocket_with_websocat "$tenant_id" "$job_id" "$duration"
elif command -v node >/dev/null 2>&1; then elif command -v node >/dev/null 2>&1; then
test_websocket_with_nodejs_builtin "$tenant_id" "$job_id" "$duration" test_websocket_with_nodejs_builtin "$tenant_id" "$job_id" "$duration"
@@ -221,382 +630,6 @@ test_websocket_with_websocat() {
fi fi
} }
# Test WebSocket using Node.js and the third-party `ws` package.
#
# Listens on the job's live endpoint for a fixed time, printing every message.
#
# Arguments:
#   $1 - tenant id
#   $2 - training job id
#   $3 - how long to listen, in seconds
# Globals read:
#   WS_BASE, ACCESS_TOKEN
# Globals written:
#   WS_PID - PID of the background node process while it is running
test_websocket_with_nodejs() {
    local tenant_id="$1"
    local job_id="$2"
    local duration="$3"

    echo "Using Node.js for WebSocket testing..."

    # Create the script inside a private, freshly created directory rather
    # than at a predictable path in the world-writable /tmp.
    local ws_tmp_dir
    ws_tmp_dir=$(mktemp -d "${TMPDIR:-/tmp}/websocket_test.XXXXXX") || {
        log_warning "Could not create a temporary directory for the WebSocket test script"
        return 1
    }
    local ws_test_script="$ws_tmp_dir/websocket_test.js"

    # The heredoc delimiter is quoted, so nothing below is expanded by the shell.
    cat > "$ws_test_script" << 'EOF'
const WebSocket = require('ws');

const tenantId = process.argv[2];
const jobId = process.argv[3];
const duration = parseInt(process.argv[4]) * 1000;
const accessToken = process.argv[5];
const wsUrl = process.argv[6];

console.log(`Connecting to: ${wsUrl}`);

const ws = new WebSocket(wsUrl, {
    headers: {
        'Authorization': `Bearer ${accessToken}`
    }
});

let messageCount = 0;
let startTime = Date.now();

ws.on('open', function() {
    console.log('✅ WebSocket connected successfully');

    // Send periodic pings (application-level text message, not a ping frame)
    const pingInterval = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN) {
            ws.send('ping');
        }
    }, 5000);

    // Close after duration
    setTimeout(() => {
        clearInterval(pingInterval);
        console.log(`\n📊 WebSocket test completed after ${duration/1000}s`);
        console.log(`📨 Total messages received: ${messageCount}`);
        if (messageCount > 0) {
            console.log('✅ WebSocket communication successful');
        } else {
            console.log('⚠️ No training progress messages received');
            console.log(' This may be normal if training completed quickly');
        }
        ws.close();
        process.exit(0);
    }, duration);
});

ws.on('message', function(data) {
    messageCount++;
    const timestamp = new Date().toLocaleTimeString();
    try {
        const message = JSON.parse(data);
        console.log(`\n[${timestamp}] 📨 Message ${messageCount}:`);
        console.log(` Type: ${message.type || 'unknown'}`);
        console.log(` Job ID: ${message.job_id || 'unknown'}`);
        if (message.data) {
            if (message.data.progress !== undefined) {
                console.log(` Progress: ${message.data.progress}%`);
            }
            if (message.data.current_step) {
                console.log(` Step: ${message.data.current_step}`);
            }
            if (message.data.current_product) {
                console.log(` Product: ${message.data.current_product}`);
            }
            if (message.data.estimated_time_remaining_minutes) {
                console.log(` ETA: ${message.data.estimated_time_remaining_minutes} minutes`);
            }
        }
        // Special handling for completion messages
        if (message.type === 'completed') {
            console.log('🎉 Training completed!');
        } else if (message.type === 'failed') {
            console.log('❌ Training failed!');
        }
    } catch (e) {
        // Not JSON (e.g. the plain-text reply to "ping") - show it verbatim
        console.log(`[${timestamp}] Raw message: ${data}`);
    }
});

ws.on('error', function(error) {
    console.log('❌ WebSocket error:', error.message);
});

ws.on('close', function(code, reason) {
    console.log(`\n🔌 WebSocket closed (code: ${code}, reason: ${reason || 'normal'})`);
    process.exit(code === 1000 ? 0 : 1);
});
EOF

    # Run Node.js WebSocket test in the background so WS_PID is visible to
    # the rest of the script while the test is running.
    local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live"
    node "$ws_test_script" "$tenant_id" "$job_id" "$duration" "$ACCESS_TOKEN" "$ws_url" &
    WS_PID=$!

    # Wait for completion
    wait $WS_PID
    local exit_code=$?
    # The child has been reaped; do not leave a stale PID behind that could
    # later be mistaken for a live process.
    WS_PID=""

    # Clean up
    rm -rf "$ws_tmp_dir"

    if [ $exit_code -eq 0 ]; then
        log_success "WebSocket test completed successfully"
    else
        log_warning "WebSocket test completed with issues"
    fi
}
# Monitor a training job's WebSocket endpoint using only Node.js built-ins.
#
# Arguments:
#   $1 - tenant id
#   $2 - training job id
#   $3 - maximum monitoring duration in seconds
# Globals read:
#   WS_BASE      - base WebSocket URL (ws:// or wss://)
#   ACCESS_TOKEN - bearer token sent in the handshake Authorization header
# Returns:
#   Exit status of the Node.js client (0 = handshake OK and test ran to
#   completion, non-zero = handshake/connection failure).
test_websocket_with_nodejs_builtin() {
    local tenant_id="$1"
    local job_id="$2"
    local duration="$3"

    echo "Using Node.js with built-in modules for WebSocket testing..."

    # Create Node.js WebSocket test script using built-in modules only.
    # The heredoc delimiter is quoted so the shell does not expand ${...} or
    # backticks inside the JavaScript.
    local ws_test_script="/tmp/websocket_test_$job_id.js"
    cat > "$ws_test_script" << 'EOF'
// WebSocket test using only built-in Node.js modules.
// argv: [2]=tenant id, [3]=job id, [4]=duration (s), [5]=token, [6]=ws url
const https = require('https');
const http = require('http');
const crypto = require('crypto');

const duration = parseInt(process.argv[4], 10) * 1000;
const accessToken = process.argv[5];
const wsUrl = process.argv[6];

console.log(`Connecting to: ${wsUrl}`);
console.log(`Duration: ${duration/1000}s`);

// Parse WebSocket URL
const url = new URL(wsUrl);
const isSecure = url.protocol === 'wss:';
const port = url.port || (isSecure ? 443 : 80);

// Create WebSocket key (required for WebSocket handshake)
const key = crypto.randomBytes(16).toString('base64');

// WebSocket handshake headers
const headers = {
    'Upgrade': 'websocket',
    'Connection': 'Upgrade',
    'Sec-WebSocket-Key': key,
    'Sec-WebSocket-Version': '13',
    'Authorization': `Bearer ${accessToken}`
};

const options = {
    hostname: url.hostname,
    port: port,
    // Keep the query string: pathname alone would silently drop it.
    path: url.pathname + url.search,
    method: 'GET',
    headers: headers
};

console.log(`Attempting WebSocket handshake to ${url.hostname}:${port}${url.pathname}`);

const client = isSecure ? https : http;
let messageCount = 0;

// Send a control frame (ping = 9, pong = 10). RFC 6455 requires every
// client-to-server frame to be masked; servers drop unmasked client frames.
function sendControlFrame(socket, opcode, payload) {
    if (!socket.writable) return;
    // Control frame payloads are limited to 125 bytes.
    const body = Buffer.from((payload || Buffer.alloc(0)).subarray(0, 125));
    const maskKey = crypto.randomBytes(4);
    for (let i = 0; i < body.length; i++) {
        body[i] ^= maskKey[i % 4];
    }
    const header = Buffer.from([0x80 | opcode, 0x80 | body.length]);
    socket.write(Buffer.concat([header, maskKey, body]));
}

const req = client.request(options);

req.on('upgrade', (res, socket, head) => {
    console.log('✅ WebSocket handshake successful');

    // Frames are binary: accumulate raw bytes. Decoding to a string first
    // would mangle header bytes >= 0x80 (e.g. 0x81 = FIN + text opcode).
    let buffer = Buffer.alloc(0);

    const onData = (data) => {
        buffer = Buffer.concat([buffer, data]);

        // Process complete WebSocket frames (fragmented messages are not
        // reassembled; each text frame is reported on its own).
        while (buffer.length >= 2) {
            const firstByte = buffer[0];
            const secondByte = buffer[1];
            const opcode = firstByte & 0x0F;
            const masked = (secondByte & 0x80) === 0x80;
            let payloadLength = secondByte & 0x7F;
            let offset = 2;

            if (payloadLength === 126) {
                if (buffer.length < offset + 2) break;
                payloadLength = buffer.readUInt16BE(offset);
                offset += 2;
            } else if (payloadLength === 127) {
                if (buffer.length < offset + 8) break;
                // Only payloads that fit in 32 bits are supported.
                if (buffer.readUInt32BE(offset) !== 0) {
                    console.log('❌ WebSocket error:', 'frame too large');
                    socket.destroy();
                    process.exit(1);
                }
                payloadLength = buffer.readUInt32BE(offset + 4);
                offset += 8;
            }

            // Servers normally send unmasked frames, but honour the mask
            // bit if it is set so the payload offset stays correct.
            let maskKey = null;
            if (masked) {
                if (buffer.length < offset + 4) break;
                maskKey = Buffer.from(buffer.subarray(offset, offset + 4));
                offset += 4;
            }

            if (buffer.length < offset + payloadLength) break;

            // Extract payload
            const payloadBytes = Buffer.from(buffer.subarray(offset, offset + payloadLength));
            buffer = buffer.subarray(offset + payloadLength);
            if (maskKey) {
                for (let i = 0; i < payloadBytes.length; i++) {
                    payloadBytes[i] ^= maskKey[i % 4];
                }
            }

            if (opcode === 1) { // Text frame
                const payload = payloadBytes.toString('utf8');
                messageCount++;
                const timestamp = new Date().toLocaleTimeString();
                try {
                    const message = JSON.parse(payload);
                    console.log(`\n[${timestamp}] 📨 Message ${messageCount}:`);
                    console.log(` Type: ${message.type || 'unknown'}`);
                    console.log(` Job ID: ${message.job_id || 'unknown'}`);
                    if (message.data) {
                        if (message.data.progress !== undefined) {
                            console.log(` Progress: ${message.data.progress}%`);
                        }
                        if (message.data.current_step) {
                            console.log(` Step: ${message.data.current_step}`);
                        }
                    }
                    if (message.type === 'completed') {
                        console.log('🎉 Training completed!');
                    } else if (message.type === 'failed') {
                        console.log('❌ Training failed!');
                    }
                } catch (e) {
                    console.log(`[${timestamp}] Raw message: ${payload}`);
                }
            } else if (opcode === 9) { // Ping from server: answer with pong
                sendControlFrame(socket, 10, payloadBytes);
            } else if (opcode === 8) { // Close frame
                console.log('🔌 WebSocket closed by server');
                socket.end();
                return;
            }
        }
    };

    socket.on('data', onData);

    socket.on('end', () => {
        console.log(`\n📊 WebSocket test completed`);
        console.log(`📨 Total messages received: ${messageCount}`);
        if (messageCount > 0) {
            console.log('✅ WebSocket communication successful');
        } else {
            console.log('⚠️ No messages received during test period');
        }
        process.exit(0);
    });

    socket.on('error', (error) => {
        console.log('❌ WebSocket error:', error.message);
        process.exit(1);
    });

    // The server may have sent frames in the same packet as the handshake
    // response; those bytes arrive in `head`, not via the 'data' event.
    if (head && head.length > 0) {
        onData(head);
    }

    // Send periodic pings to keep connection alive
    const pingInterval = setInterval(() => {
        sendControlFrame(socket, 9);
    }, 10000);

    // Close after duration
    setTimeout(() => {
        clearInterval(pingInterval);
        console.log(`\n⏰ Test duration (${duration/1000}s) completed`);
        console.log(`📨 Total messages received: ${messageCount}`);
        if (messageCount > 0) {
            console.log('✅ WebSocket communication successful');
        } else {
            console.log('⚠️ No training progress messages received');
            console.log(' This is normal if training completed before WebSocket connection');
        }
        socket.end();
        process.exit(0);
    }, duration);
});

req.on('response', (res) => {
    console.log(`❌ HTTP response instead of WebSocket upgrade: ${res.statusCode}`);
    console.log('Response headers:', res.headers);
    let body = '';
    res.on('data', chunk => body += chunk);
    // Exit only once the body has been read, so the server's error
    // message is actually printed.
    res.on('end', () => {
        if (body) console.log('Response body:', body);
        process.exit(1);
    });
    res.on('error', () => process.exit(1));
    // Fallback in case the connection drops before 'end' is emitted.
    res.on('close', () => process.exit(1));
});

req.on('error', (error) => {
    console.log('❌ Connection error:', error.message);
    process.exit(1);
});

req.end();
EOF

    # Run the Node.js WebSocket test
    local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live"
    echo "Starting WebSocket test..."
    node "$ws_test_script" "$tenant_id" "$job_id" "$duration" "$ACCESS_TOKEN" "$ws_url"
    local exit_code=$?

    # Clean up
    rm -f "$ws_test_script"
    return $exit_code
}
# Fallback: Test WebSocket using curl (limited functionality)
#
# Polls the job status endpoint over HTTP every 5 seconds to approximate
# real-time progress when no WebSocket client is available.
#
# Arguments:
#   $1 - tenant id
#   $2 - training job id
#   $3 - maximum polling duration in seconds
# Globals read:
#   API_BASE, ACCESS_TOKEN
# Globals written:
#   STATUS_RESPONSE - body of the most recent status poll
test_websocket_with_curl() {
    local tenant_id="$1"
    local job_id="$2"
    local duration="$3"

    log_warning "WebSocket testing tools not available (websocat/node.js)"
    echo "Falling back to HTTP polling simulation..."

    # Create a simple HTTP-based progress polling simulation
    local poll_endpoint="$API_BASE/api/v1/tenants/$tenant_id/training/jobs/$job_id/status"
    local end_time=$(($(date +%s) + duration))
    local poll_count=0

    echo "Simulating real-time updates by polling: $poll_endpoint"
    echo "Duration: ${duration}s"

    while [ "$(date +%s)" -lt "$end_time" ]; do
        poll_count=$((poll_count + 1))
        echo ""
        echo "[$(date '+%H:%M:%S')] Poll #$poll_count - Checking training status..."

        # --max-time keeps one hung request from stalling the loop far
        # beyond the requested duration.
        STATUS_RESPONSE=$(curl -s --max-time 10 -X GET "$poll_endpoint" \
            -H "Authorization: Bearer $ACCESS_TOKEN" \
            -H "X-Tenant-ID: $tenant_id")

        echo "Response:"
        # Pretty-print when the body is valid JSON, otherwise show it raw
        echo "$STATUS_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$STATUS_RESPONSE"

        # Check if training is complete
        if echo "$STATUS_RESPONSE" | grep -q '"status".*"completed"\|"status".*"failed"'; then
            log_success "Training status detected as complete/failed - stopping polling"
            break
        fi

        sleep 5
    done

    log_success "HTTP polling simulation completed ($poll_count polls)"
    # websocat is a Rust tool and is not distributed through npm
    echo "💡 For real WebSocket testing, install: cargo install websocat"
}
# Wait for WebSocket messages and analyze them # Wait for WebSocket messages and analyze them
wait_for_websocket_messages() { wait_for_websocket_messages() {
local ws_log="$1" local ws_log="$1"
@@ -660,8 +693,8 @@ wait_for_websocket_messages() {
# Enhanced training step with WebSocket testing # Enhanced training step with WebSocket testing
enhanced_training_step_with_completion_check() { enhanced_training_step_with_completion_check() {
echo -e "${STEP_ICONS[3]} ${PURPLE}STEP 4: MODEL TRAINING WITH WEBSOCKET MONITORING${NC}" echo -e "${STEP_ICONS[3]} ${PURPLE}STEP 4: MODEL TRAINING WITH SMART WEBSOCKET MONITORING${NC}"
echo "Enhanced training step with real-time progress monitoring" echo "Enhanced training step with completion-aware progress monitoring"
echo "" echo ""
log_step "4.1. Initiating model training with FULL dataset" log_step "4.1. Initiating model training with FULL dataset"
@@ -694,42 +727,50 @@ enhanced_training_step_with_completion_check() {
echo " Job ID: $WEBSOCKET_JOB_ID" echo " Job ID: $WEBSOCKET_JOB_ID"
echo " Status: $JOB_STATUS" echo " Status: $JOB_STATUS"
# Check if training completed instantly # Determine monitoring strategy based on initial status
if [ "$JOB_STATUS" = "completed" ]; then if [ "$JOB_STATUS" = "completed" ]; then
log_warning "Training completed instantly - no real-time progress to monitor" log_warning "Training completed instantly - no real-time progress to monitor"
echo " This can happen when:" echo " This can happen when:"
echo " • Models are already trained and cached"
echo " • No valid products found in sales data" echo " • No valid products found in sales data"
echo " • Training data is insufficient" echo " • Training data is insufficient"
echo " • Models are already trained and cached"
echo ""
# Show training results # Show training results
TOTAL_PRODUCTS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.total_products") TOTAL_PRODUCTS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.total_products")
SUCCESSFUL_TRAININGS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.successful_trainings") SUCCESSFUL_TRAININGS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.successful_trainings")
SALES_RECORDS=$(extract_json_field "$TRAINING_RESPONSE" "data_summary.sales_records") SALES_RECORDS=$(extract_json_field "$TRAINING_RESPONSE" "data_summary.sales_records")
echo ""
echo "📊 Training Summary:" echo "📊 Training Summary:"
echo " Sales records: $SALES_RECORDS" echo " Sales records: $SALES_RECORDS"
echo " Products found: $TOTAL_PRODUCTS" echo " Products found: $TOTAL_PRODUCTS"
echo " Successful trainings: $SUCCESSFUL_TRAININGS" echo " Successful trainings: $SUCCESSFUL_TRAININGS"
if [ "$TOTAL_PRODUCTS" = "0" ]; then # Brief WebSocket connection test
log_warning "No products found for training"
echo " Possible causes:"
echo " • CSV doesn't contain valid product names"
echo " • Product column is missing or malformed"
echo " • Insufficient sales data per product"
fi
# Still test WebSocket for demonstration
log_step "4.2. Testing WebSocket endpoint (demonstration mode)" log_step "4.2. Testing WebSocket endpoint (demonstration mode)"
echo "Even though training is complete, testing WebSocket connection..." echo "Testing WebSocket connection for 10 seconds..."
test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "10" test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "10"
else else
# Training is in progress - monitor with WebSocket # Training is in progress - use smart monitoring
log_step "4.2. Connecting to WebSocket for real-time progress monitoring" log_step "4.2. Starting smart WebSocket monitoring"
test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "$WS_TEST_DURATION" echo " Strategy: Monitor until job completion"
echo " Maximum wait time: ${WS_TEST_DURATION}s (safety timeout)"
echo " Will automatically close when training completes"
echo ""
# Use enhanced monitoring with longer timeout for real training
local SMART_DURATION=$WS_TEST_DURATION
# Estimate duration based on data size (optional enhancement)
if [ -n "$SALES_RECORDS" ] && [ "$SALES_RECORDS" -gt 1000 ]; then
# For large datasets, extend timeout
SMART_DURATION=$((WS_TEST_DURATION * 2))
echo " 📊 Large dataset detected ($SALES_RECORDS records)"
echo " 🕐 Extended timeout to ${SMART_DURATION}s for thorough training"
fi
test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "$SMART_DURATION"
fi fi
else else
@@ -786,41 +827,6 @@ services_check() {
done done
} }
# Report which WebSocket testing tool is installed on this machine.
#
# Returns 0 when websocat or Node.js is on PATH, 1 when neither is
# (callers are then expected to fall back to HTTP polling).
check_websocket_prerequisites() {
    echo -e "${PURPLE}🔍 Checking WebSocket testing prerequisites...${NC}"

    # websocat is checked first
    if command -v websocat >/dev/null 2>&1; then
        log_success "websocat found - will use for WebSocket testing"
        return 0
    fi

    # No websocat and no Node.js: nothing can speak WebSocket here
    if ! command -v node >/dev/null 2>&1; then
        log_warning "Neither websocat nor Node.js found"
        echo " WebSocket testing will use HTTP polling fallback"
        echo " For better testing, install one of:"
        echo " • websocat: cargo install websocat"
        echo " • Node.js: https://nodejs.org/"
        return 1
    fi

    # Node.js is present; report its version
    local detected_version
    detected_version=$(node --version 2>/dev/null || echo "unknown")
    log_success "Node.js found ($detected_version) - will use for WebSocket testing"

    # Probe for the optional 'ws' module by trying to require it
    if ! node -e "require('ws')" 2>/dev/null; then
        log_warning "Node.js 'ws' module not found"
        echo " Install with: npm install -g ws"
        echo " Will attempt to use built-in functionality..."
    else
        log_success "Node.js 'ws' module available"
    fi

    return 0
}
services_check services_check
echo "" echo ""
@@ -1167,7 +1173,7 @@ echo ""
# STEP 4: MODEL TRAINING (ONBOARDING PAGE STEP 4) # STEP 4: MODEL TRAINING (ONBOARDING PAGE STEP 4)
# ================================================================= # =================================================================
check_websocket_prerequisites test_websocket_connection
enhanced_training_step_with_completion_check enhanced_training_step_with_completion_check