diff --git a/services/training/app/api/websocket.py b/services/training/app/api/websocket.py index 2d75be35..60184651 100644 --- a/services/training/app/api/websocket.py +++ b/services/training/app/api/websocket.py @@ -5,19 +5,19 @@ WebSocket endpoints for real-time training progress updates import json import asyncio -import logging from typing import Dict, Any from fastapi import WebSocket, WebSocketDisconnect, Depends, HTTPException from fastapi.routing import APIRouter +import structlog +logger = structlog.get_logger(__name__) + from app.services.messaging import training_publisher from shared.auth.decorators import ( get_current_user_dep, get_current_tenant_id_dep ) -logger = logging.getLogger(__name__) - # Create WebSocket router websocket_router = APIRouter() @@ -48,8 +48,9 @@ class ConnectionManager: logger.info(f"WebSocket disconnected for job {job_id}, connection {connection_id}") async def send_to_job(self, job_id: str, message: dict): - """Send message to all connections for a specific job""" + """Send message to all connections for a specific job with better error handling""" if job_id not in self.active_connections: + logger.debug(f"No active connections for job {job_id}") return # Send to all connections for this job @@ -58,6 +59,7 @@ class ConnectionManager: for connection_id, websocket in self.active_connections[job_id].items(): try: await websocket.send_json(message) + logger.debug(f"šŸ“¤ Sent {message.get('type', 'unknown')} to connection {connection_id}") except Exception as e: logger.warning(f"Failed to send message to connection {connection_id}: {e}") disconnected_connections.append(connection_id) @@ -65,6 +67,11 @@ class ConnectionManager: # Clean up disconnected connections for connection_id in disconnected_connections: self.disconnect(job_id, connection_id) + + # Log successful sends + active_count = len(self.active_connections.get(job_id, {})) + if active_count > 0: + logger.info(f"šŸ“” Sent {message.get('type', 'unknown')} message to {active_count} connection(s) for job {job_id}") # Global connection manager connection_manager = ConnectionManager() @@ -76,40 +83,31 @@ async def training_progress_websocket( job_id: str ): """ - WebSocket endpoint for real-time training progress updates - - Message format sent to client: - { - "type": "progress" | "completed" | "failed" | "step_completed", - "job_id": "job_123", - "timestamp": "2025-07-30T19:08:53Z", - "data": { - "progress": 45, - "current_step": "Training model for bread", - "current_product": "bread", - "products_completed": 2, - "products_total": 5, - "estimated_time_remaining_minutes": 8 - } - } + FIXED WebSocket endpoint for real-time training progress updates + Prevents premature disconnection during training """ connection_id = f"{tenant_id}_{id(websocket)}" # Accept connection await connection_manager.connect(websocket, job_id, connection_id) + logger.info(f"WebSocket connection established for job {job_id}") # Set up RabbitMQ consumer for this job consumer_task = None + training_completed = False # Track training completion try: # Start RabbitMQ consumer + logger.info(f"Setting up RabbitMQ consumer for job {job_id}") consumer_task = asyncio.create_task( setup_rabbitmq_consumer_for_job(job_id, tenant_id) ) + # Give consumer time to set up + await asyncio.sleep(0.5) + # Send initial status if available try: - # You can fetch current job status from database here initial_status = await get_current_job_status(job_id, tenant_id) if initial_status: await websocket.send_json({ @@ -117,48 +115,92 @@ async def training_progress_websocket( "job_id": job_id, "data": initial_status }) + logger.info(f"Sent initial status for job {job_id}") except Exception as e: logger.warning(f"Failed to send initial status: {e}") - # Keep connection alive and handle client messages - while True: + # Keep connection alive - IMPROVED ERROR HANDLING + last_activity = asyncio.get_event_loop().time() + + while not training_completed: try: - # Wait for client ping or other messages - message = await websocket.receive_text() - - if message == "ping": - await websocket.send_text("pong") - elif message == "get_status": - # Send current status on demand - current_status = await get_current_job_status(job_id, tenant_id) - if current_status: + # Wait for client messages with timeout + try: + message = await asyncio.wait_for(websocket.receive_text(), timeout=30.0) + last_activity = asyncio.get_event_loop().time() + + if message == "ping": + await websocket.send_text("pong") + logger.debug(f"Ping received from job {job_id}") + elif message == "get_status": + current_status = await get_current_job_status(job_id, tenant_id) + if current_status: + await websocket.send_json({ + "type": "current_status", + "job_id": job_id, + "data": current_status + }) + elif message == "close": + logger.info(f"Client requested connection close for job {job_id}") + break + + except asyncio.TimeoutError: + # No message received in 30 seconds - send heartbeat + current_time = asyncio.get_event_loop().time() + if current_time - last_activity > 60: # 60 seconds of inactivity + logger.warning(f"Long inactivity period for job {job_id}, sending heartbeat") + + try: await websocket.send_json({ - "type": "current_status", + "type": "heartbeat", "job_id": job_id, - "data": current_status + "timestamp": datetime.utcnow().isoformat() }) + except Exception as e: + logger.error(f"Failed to send heartbeat for job {job_id}: {e}") + break except WebSocketDisconnect: - logger.info(f"WebSocket disconnected for job {job_id}") + logger.info(f"WebSocket client disconnected for job {job_id}") break except Exception as e: logger.error(f"WebSocket error for job {job_id}: {e}") - break + # Don't break immediately - try to recover + await asyncio.sleep(1) + logger.info(f"WebSocket loop ended for job {job_id}, training_completed: {training_completed}") + + except Exception as e: + logger.error(f"Critical WebSocket error for job {job_id}: {e}") + finally: - # Clean up + # IMPROVED CLEANUP - Don't cancel consumer unless truly disconnecting + logger.info(f"Cleaning up WebSocket connection for job {job_id}") connection_manager.disconnect(job_id, connection_id) + # Only cancel consumer if we're truly done (not just a temporary error) if consumer_task and not consumer_task.done(): - consumer_task.cancel() + if training_completed: + logger.info(f"Training completed, cancelling consumer for job {job_id}") + consumer_task.cancel() + else: + logger.warning(f"WebSocket disconnected but training not completed for job {job_id}") + # Let the consumer continue running for other potential connections + # Don't cancel it unless we're sure the job is done + try: await consumer_task except asyncio.CancelledError: - pass + logger.info(f"Consumer task cancelled for job {job_id}") + except Exception as e: + logger.error(f"Consumer task error for job {job_id}: {e}") + async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str): """Set up RabbitMQ consumer to listen for training events for a specific job""" + logger.info(f"šŸš€ Setting up RabbitMQ consumer for job {job_id}") + try: # Create a unique queue for this WebSocket connection queue_name = f"websocket_training_{job_id}_{tenant_id}" @@ -170,12 +212,16 @@ async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str): body = message.body.decode() data = json.loads(body) + logger.debug(f"šŸ” Received message for job {job_id}: {data.get('event_type', 'unknown')}") + # Extract event data event_type = data.get("event_type", "unknown") event_data = data.get("data", {}) # Only process messages for this specific job - if event_data.get("job_id") != job_id: + message_job_id = event_data.get("job_id") if event_data else None + if message_job_id != job_id: + logger.debug(f"ā­ļø Ignoring message for different job: {message_job_id}") await message.ack() return @@ -187,36 +233,76 @@ async def setup_rabbitmq_consumer_for_job(job_id: str, tenant_id: str): "data": event_data } + logger.info(f"šŸ“¤ Forwarding {event_type} message to WebSocket clients for job {job_id}") + # Send to all WebSocket connections for this job await connection_manager.send_to_job(job_id, websocket_message) + # Check if this is a completion message + if event_type in ["training.completed", "training.failed"]: + logger.info(f"šŸŽÆ Training completion detected for job {job_id}: {event_type}") + # Mark training as completed (you might want to store this in a global state) + # For now, we'll let the WebSocket handle this through the message + # Acknowledge the message await message.ack() - logger.debug(f"Forwarded training event to WebSocket: {event_type}") + logger.debug(f"āœ… Successfully processed {event_type} for job {job_id}") except Exception as e: - logger.error(f"Error handling training message for WebSocket: {e}") + logger.error(f"āŒ Error handling training message for job {job_id}: {e}") + import traceback + logger.error(f"šŸ’„ Traceback: {traceback.format_exc()}") await message.nack(requeue=False) + # Check if training_publisher is connected + if not training_publisher.connected: + logger.warning(f"āš ļø Training publisher not connected for job {job_id}, attempting to connect...") + success = await training_publisher.connect() + if not success: + logger.error(f"āŒ Failed to connect training_publisher for job {job_id}") + return + # Subscribe to training events - await training_publisher.consume_events( + logger.info(f"šŸ”— Subscribing to training events for job {job_id}") + success = await training_publisher.consume_events( exchange_name="training.events", queue_name=queue_name, routing_key="training.*", # Listen to all training events callback=handle_training_message ) + if success: + logger.info(f"āœ… Successfully set up RabbitMQ consumer for job {job_id} (queue: {queue_name})") + + # Keep the consumer running indefinitely until cancelled + try: + while True: + await asyncio.sleep(10) # Keep consumer alive + logger.debug(f"šŸ”„ Consumer heartbeat for job {job_id}") + + except asyncio.CancelledError: + logger.info(f"šŸ›‘ Consumer cancelled for job {job_id}") + raise + except Exception as e: + logger.error(f"šŸ’„ Consumer error for job {job_id}: {e}") + raise + else: + logger.error(f"āŒ Failed to set up RabbitMQ consumer for job {job_id}") + except Exception as e: - logger.error(f"Failed to set up RabbitMQ consumer for WebSocket: {e}") + logger.error(f"šŸ’„ Exception in setup_rabbitmq_consumer_for_job for job {job_id}: {e}") + import traceback + logger.error(f"šŸ”„ Traceback: {traceback.format_exc()}") + def map_event_type_to_websocket_type(rabbitmq_event_type: str) -> str: """Map RabbitMQ event types to WebSocket message types""" mapping = { "training.started": "started", "training.progress": "progress", - "training.completed": "completed", - "training.failed": "failed", + "training.completed": "completed", # This is the key completion event + "training.failed": "failed", # This is also a completion event "training.cancelled": "cancelled", "training.step.completed": "step_completed", "training.product.started": "product_started", diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh index 3a76af02..bf16c94e 100755 --- a/tests/test_onboarding_flow.sh +++ b/tests/test_onboarding_flow.sh @@ -13,7 +13,7 @@ TEST_PASSWORD="TestPassword123!" TEST_NAME="Test Bakery Owner" REAL_CSV_FILE="bakery_sales_2023_2024.csv" WS_BASE="ws://localhost:8002/api/v1/ws" -WS_TEST_DURATION=200 # seconds to listen for WebSocket messages +WS_TEST_DURATION=2000 # seconds to listen for WebSocket messages WS_PID="" @@ -158,20 +158,429 @@ check_timezone_error() { return 1 # No timezone error } -# Function to test WebSocket connection using websocat (if available) or Node.js +test_websocket_with_nodejs_builtin() { + local tenant_id="$1" + local job_id="$2" + local max_duration="$3" # Maximum time to wait (fallback) + + echo "Using Node.js with built-in modules for WebSocket testing..." + echo "Will monitor until job completion or ${max_duration}s timeout" + + # Create ENHANCED Node.js WebSocket test script + local ws_test_script="/tmp/websocket_test_$job_id.js" + cat > "$ws_test_script" << 'EOF' +// ENHANCED WebSocket test - waits for job completion +const https = require('https'); +const http = require('http'); +const crypto = require('crypto'); + +const tenantId = process.argv[2]; +const jobId = process.argv[3]; +const maxDuration = parseInt(process.argv[4]) * 1000; // Convert to milliseconds +const accessToken = process.argv[5]; +const wsUrl = process.argv[6]; + +console.log(`šŸš€ Starting enhanced WebSocket monitoring`); +console.log(`Connecting to: ${wsUrl}`); +console.log(`Will wait for job completion (max ${maxDuration/1000}s)`); + +// Parse WebSocket URL +const url = new URL(wsUrl); +const isSecure = url.protocol === 'wss:'; +const port = url.port || (isSecure ? 443 : 80); + +// Create WebSocket key +const key = crypto.randomBytes(16).toString('base64'); + +// WebSocket handshake headers +const headers = { + 'Upgrade': 'websocket', + 'Connection': 'Upgrade', + 'Sec-WebSocket-Key': key, + 'Sec-WebSocket-Version': '13', + 'Authorization': `Bearer ${accessToken}` +}; + +const options = { + hostname: url.hostname, + port: port, + path: url.pathname, + method: 'GET', + headers: headers +}; + +console.log(`Attempting WebSocket handshake to ${url.hostname}:${port}${url.pathname}`); + +const client = isSecure ? https : http; +let messageCount = 0; +let jobCompleted = false; +let lastProgressUpdate = Date.now(); +let highestProgress = 0; + +// Enhanced job tracking +const jobStats = { + startTime: Date.now(), + progressUpdates: 0, + stepsCompleted: [], + productsProcessed: [], + errors: [] +}; + +const req = client.request(options); + +req.on('upgrade', (res, socket, head) => { + console.log('āœ… WebSocket handshake successful'); + console.log('šŸ“” Monitoring training progress...\n'); + + let buffer = Buffer.alloc(0); + + socket.on('data', (data) => { + buffer = Buffer.concat([buffer, data]); + + // WebSocket frame parsing + while (buffer.length >= 2) { + const firstByte = buffer[0]; + const secondByte = buffer[1]; + + const fin = (firstByte & 0x80) === 0x80; + const opcode = firstByte & 0x0F; + const masked = (secondByte & 0x80) === 0x80; + let payloadLength = secondByte & 0x7F; + + let offset = 2; + + // Handle extended payload length + if (payloadLength === 126) { + if (buffer.length < offset + 2) break; + payloadLength = buffer.readUInt16BE(offset); + offset += 2; + } else if (payloadLength === 127) { + if (buffer.length < offset + 8) break; + const high = buffer.readUInt32BE(offset); + const low = buffer.readUInt32BE(offset + 4); + if (high !== 0) { + console.log('āš ļø Large payload detected, skipping...'); + buffer = buffer.slice(offset + 8); + continue; + } + payloadLength = low; + offset += 8; + } + + // Check if we have the complete frame + if (buffer.length < offset + payloadLength) { + break; // Wait for more data + } + + // Extract payload + const payload = buffer.slice(offset, offset + payloadLength); + buffer = buffer.slice(offset + payloadLength); + + // Handle different frame types + if (opcode === 1 && fin) { // Text frame + messageCount++; + lastProgressUpdate = Date.now(); + const timestamp = new Date().toLocaleTimeString(); + + try { + const messageText = payload.toString('utf8'); + const message = JSON.parse(messageText); + + // Enhanced message processing + processTrainingMessage(message, timestamp); + + } catch (e) { + const rawText = payload.toString('utf8'); + console.log(`[${timestamp}] āš ļø Raw message: ${rawText.substring(0, 200)}${rawText.length > 200 ? '...' : ''}`); + } + + } else if (opcode === 8) { // Close frame + console.log('šŸ”Œ WebSocket closed by server'); + socket.end(); + return; + + } else if (opcode === 9) { // Ping frame + // Send pong response + const pongFrame = Buffer.concat([ + Buffer.from([0x8A, payload.length]), + payload + ]); + socket.write(pongFrame); + + } else if (opcode === 10) { // Pong frame + // Ignore pong responses + continue; + } + } + }); + + // Enhanced message processing function + function processTrainingMessage(message, timestamp) { + const messageType = message.type || 'unknown'; + const data = message.data || {}; + + console.log(`[${timestamp}] šŸ“Ø Message ${messageCount}: ${messageType.toUpperCase()}`); + + // Track job statistics + if (messageType === 'progress') { + jobStats.progressUpdates++; + const progress = data.progress || 0; + const step = data.current_step || 'Unknown step'; + const product = data.current_product; + + // Update highest progress + if (progress > highestProgress) { + highestProgress = progress; + } + + // Track steps + if (step && !jobStats.stepsCompleted.includes(step)) { + jobStats.stepsCompleted.push(step); + } + + // Track products + if (product && !jobStats.productsProcessed.includes(product)) { + jobStats.productsProcessed.push(product); + } + + // Display progress with enhanced formatting + console.log(` šŸ“Š Progress: ${progress}% (${step})`); + if (product) { + console.log(` šŸž Product: ${product}`); + } + if (data.products_completed && data.products_total) { + console.log(` šŸ“¦ Products: ${data.products_completed}/${data.products_total} completed`); + } + if (data.estimated_time_remaining_minutes) { + console.log(` ā±ļø ETA: ${data.estimated_time_remaining_minutes} minutes`); + } + + } else if (messageType === 'completed') { + jobCompleted = true; + const duration = Math.round((Date.now() - jobStats.startTime) / 1000); + + console.log(`\nšŸŽ‰ TRAINING COMPLETED SUCCESSFULLY!`); + console.log(` ā±ļø Total Duration: ${duration}s`); + + if (data.results) { + const results = data.results; + if (results.successful_trainings !== undefined) { + console.log(` āœ… Models Trained: ${results.successful_trainings}`); + } + if (results.total_products !== undefined) { + console.log(` šŸ“¦ Total Products: ${results.total_products}`); + } + if (results.success_rate !== undefined) { + console.log(` šŸ“ˆ Success Rate: ${results.success_rate}%`); + } + } + + // Close connection after completion + setTimeout(() => { + console.log('\nšŸ“Š Training job completed - closing WebSocket connection'); + socket.end(); + }, 2000); // Wait 2 seconds to ensure all final messages are received + + } else if (messageType === 'failed') { + jobCompleted = true; + jobStats.errors.push(data); + + console.log(`\nāŒ TRAINING FAILED!`); + if (data.error) { + console.log(` šŸ’„ Error: ${data.error}`); + } + if (data.error_details) { + console.log(` šŸ“ Details: ${JSON.stringify(data.error_details, null, 2)}`); + } + + // Close connection after failure + setTimeout(() => { + console.log('\nšŸ“Š Training job failed - closing WebSocket connection'); + socket.end(); + }, 2000); + + } else if (messageType === 'step_completed') { + console.log(` āœ… Step completed: ${data.step_name || 'Unknown'}`); + + } else if (messageType === 'product_started') { + console.log(` šŸš€ Started training: ${data.product_name || 'Unknown product'}`); + + } else if (messageType === 'product_completed') { + console.log(` āœ… Product completed: ${data.product_name || 'Unknown product'}`); + if (data.metrics) { + console.log(` šŸ“Š Metrics: ${JSON.stringify(data.metrics, null, 2)}`); + } + } + + console.log(''); // Add spacing between messages + } + + socket.on('end', () => { + const duration = Math.round((Date.now() - jobStats.startTime) / 1000); + + console.log(`\nšŸ“Š WebSocket connection ended`); + console.log(`šŸ“Ø Total messages received: ${messageCount}`); + console.log(`ā±ļø Connection duration: ${duration}s`); + console.log(`šŸ“ˆ Highest progress reached: ${highestProgress}%`); + + if (jobCompleted) { + console.log('āœ… Job completed successfully - connection closed normally'); + process.exit(0); + } else { + console.log('āš ļø Connection ended before job completion'); + console.log(`šŸ“Š Progress reached: ${highestProgress}%`); + console.log(`šŸ“‹ Steps completed: ${jobStats.stepsCompleted.length}`); + process.exit(1); + } + }); + + socket.on('error', (error) => { + console.log(`āŒ WebSocket error: ${error.message}`); + process.exit(1); + }); + + // Enhanced ping mechanism - send pings more frequently + const pingInterval = setInterval(() => { + if (socket.writable && !jobCompleted) { + try { + const pingFrame = Buffer.from([0x89, 0x00]); + socket.write(pingFrame); + } catch (e) { + // Ignore ping errors + } + } + }, 5000); // Ping every 5 seconds + + // Heartbeat check - ensure we're still receiving messages + const heartbeatInterval = setInterval(() => { + if (!jobCompleted) { + const timeSinceLastMessage = Date.now() - lastProgressUpdate; + + if (timeSinceLastMessage > 60000) { // 60 seconds without messages + console.log('\nāš ļø No messages received for 60 seconds'); + console.log(' This could indicate the training is stuck or connection issues'); + console.log(` Last progress: ${highestProgress}%`); + } else if (timeSinceLastMessage > 30000) { // 30 seconds warning + console.log(`\nšŸ’¤ Quiet period: ${Math.round(timeSinceLastMessage/1000)}s since last update`); + console.log(' (This is normal during intensive training phases)'); + } + } + }, 15000); // Check every 15 seconds + + // Safety timeout - close connection if max duration exceeded + const safetyTimeout = setTimeout(() => { + if (!jobCompleted) { + clearInterval(pingInterval); + clearInterval(heartbeatInterval); + + console.log(`\nā° Maximum duration (${maxDuration/1000}s) reached`); + console.log(`šŸ“Š Final status:`); + console.log(` šŸ“Ø Messages received: ${messageCount}`); + console.log(` šŸ“ˆ Progress reached: ${highestProgress}%`); + console.log(` šŸ“‹ Steps completed: ${jobStats.stepsCompleted.length}`); + console.log(` šŸž Products processed: ${jobStats.productsProcessed.length}`); + + if (messageCount > 0) { + console.log('\nāœ… WebSocket communication was successful!'); + console.log(' Training may still be running - check server logs for completion'); + } else { + console.log('\nāš ļø No messages received during monitoring period'); + } + + socket.end(); + } + }, maxDuration); + + // Clean up intervals when job completes + socket.on('end', () => { + clearInterval(pingInterval); + clearInterval(heartbeatInterval); + clearTimeout(safetyTimeout); + }); +}); + +req.on('response', (res) => { + console.log(`āŒ HTTP response instead of WebSocket upgrade: ${res.statusCode}`); + console.log('Response headers:', res.headers); + + let body = ''; + res.on('data', chunk => body += chunk); + res.on('end', () => { + if (body) console.log('Response body:', body); + process.exit(1); + }); +}); + +req.on('error', (error) => { + console.log(`āŒ Connection error: ${error.message}`); + process.exit(1); +}); + +req.end(); +EOF + + # Run the ENHANCED Node.js WebSocket test + local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live" + echo "Starting enhanced WebSocket monitoring..." + node "$ws_test_script" "$tenant_id" "$job_id" "$max_duration" "$ACCESS_TOKEN" "$ws_url" + local exit_code=$? + + # Clean up + rm -f "$ws_test_script" + + if [ $exit_code -eq 0 ]; then + log_success "Training job completed successfully!" + echo " šŸ“” WebSocket monitoring detected job completion" + echo " šŸŽ‰ Real-time progress tracking worked perfectly" + else + log_warning "WebSocket monitoring ended before job completion" + echo " šŸ“Š Check the progress logs above for details" + fi + + return $exit_code +} + + +install_websocat_if_needed() { + if ! command -v websocat >/dev/null 2>&1; then + echo "šŸ“¦ Installing websocat for better WebSocket testing..." + + # Try to install websocat (works on most Linux systems) + if command -v cargo >/dev/null 2>&1; then + cargo install websocat 2>/dev/null || true + elif [ -x "$(command -v wget)" ]; then + wget -q -O /tmp/websocat "https://github.com/vi/websocat/releases/latest/download/websocat.x86_64-unknown-linux-musl" 2>/dev/null || true + if [ -f /tmp/websocat ]; then + chmod +x /tmp/websocat + sudo mv /tmp/websocat /usr/local/bin/ 2>/dev/null || mv /tmp/websocat ~/bin/ 2>/dev/null || true + fi + fi + + if command -v websocat >/dev/null 2>&1; then + log_success "websocat installed successfully" + return 0 + else + log_warning "websocat installation failed, using Node.js fallback" + return 1 + fi + fi + return 0 +} + +# IMPROVED: WebSocket connection function with better tool selection test_websocket_connection() { local tenant_id="$1" local job_id="$2" local duration="$3" - log_step "4.2. Testing WebSocket connection for real-time training progress" + log_step "4.2. Connecting to WebSocket for real-time progress monitoring" echo "WebSocket URL: $WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live" echo "Test duration: ${duration}s" echo "" - # Check if websocat is available - if command -v websocat >/dev/null 2>&1; then + # Try to install websocat if not available + if install_websocat_if_needed; then test_websocket_with_websocat "$tenant_id" "$job_id" "$duration" elif command -v node >/dev/null 2>&1; then test_websocket_with_nodejs_builtin "$tenant_id" "$job_id" "$duration" @@ -221,382 +630,6 @@ test_websocket_with_websocat() { fi } -# Test WebSocket using Node.js -test_websocket_with_nodejs() { - local tenant_id="$1" - local job_id="$2" - local duration="$3" - - echo "Using Node.js for WebSocket testing..." - - # Create Node.js WebSocket test script - local ws_test_script="/tmp/websocket_test_$job_id.js" - cat > "$ws_test_script" << 'EOF' -const WebSocket = require('ws'); - -const tenantId = process.argv[2]; -const jobId = process.argv[3]; -const duration = parseInt(process.argv[4]) * 1000; -const accessToken = process.argv[5]; -const wsUrl = process.argv[6]; - -console.log(`Connecting to: ${wsUrl}`); - -const ws = new WebSocket(wsUrl, { - headers: { - 'Authorization': `Bearer ${accessToken}` - } -}); - -let messageCount = 0; -let startTime = Date.now(); - -ws.on('open', function() { - console.log('āœ… WebSocket connected successfully'); - - // Send periodic pings - const pingInterval = setInterval(() => { - if (ws.readyState === WebSocket.OPEN) { - ws.send('ping'); - } - }, 5000); - - // Close after duration - setTimeout(() => { - clearInterval(pingInterval); - console.log(`\nšŸ“Š WebSocket test completed after ${duration/1000}s`); - console.log(`šŸ“Ø Total messages received: ${messageCount}`); - if (messageCount > 0) { - console.log('āœ… WebSocket communication successful'); - } else { - console.log('āš ļø No training progress messages received'); - console.log(' This may be normal if training completed quickly'); - } - ws.close(); - process.exit(0); - }, duration); -}); - -ws.on('message', function(data) { - messageCount++; - const timestamp = new Date().toLocaleTimeString(); - - try { - const message = JSON.parse(data); - console.log(`\n[${timestamp}] šŸ“Ø Message ${messageCount}:`); - console.log(` Type: ${message.type || 'unknown'}`); - console.log(` Job ID: ${message.job_id || 'unknown'}`); - - if (message.data) { - if (message.data.progress !== undefined) { - console.log(` Progress: ${message.data.progress}%`); - } - if (message.data.current_step) { - console.log(` Step: ${message.data.current_step}`); - } - if (message.data.current_product) { - console.log(` Product: ${message.data.current_product}`); - } - if (message.data.estimated_time_remaining_minutes) { - console.log(` ETA: ${message.data.estimated_time_remaining_minutes} minutes`); - } - } - - // Special handling for completion messages - if (message.type === 'completed') { - console.log('šŸŽ‰ Training completed!'); - } else if (message.type === 'failed') { - console.log('āŒ Training failed!'); - } - - } catch (e) { - console.log(`[${timestamp}] Raw message: ${data}`); - } -}); - -ws.on('error', function(error) { - console.log('āŒ WebSocket error:', error.message); -}); - -ws.on('close', function(code, reason) { - console.log(`\nšŸ”Œ WebSocket closed (code: ${code}, reason: ${reason || 'normal'})`); - process.exit(code === 1000 ? 0 : 1); -}); -EOF - - # Run Node.js WebSocket test - local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live" - node "$ws_test_script" "$tenant_id" "$job_id" "$duration" "$ACCESS_TOKEN" "$ws_url" & - WS_PID=$! - - # Wait for completion - wait $WS_PID - local exit_code=$? - - # Clean up - rm -f "$ws_test_script" - - if [ $exit_code -eq 0 ]; then - log_success "WebSocket test completed successfully" - else - log_warning "WebSocket test completed with issues" - fi -} - -test_websocket_with_nodejs_builtin() { - local tenant_id="$1" - local job_id="$2" - local duration="$3" - - echo "Using Node.js with built-in modules for WebSocket testing..." - - # Create Node.js WebSocket test script using built-in modules only - local ws_test_script="/tmp/websocket_test_$job_id.js" - cat > "$ws_test_script" << 'EOF' -// WebSocket test using only built-in Node.js modules -const { WebSocket } = require('node:http'); -const https = require('https'); -const http = require('http'); -const crypto = require('crypto'); - -const tenantId = process.argv[2]; -const jobId = process.argv[3]; -const duration = parseInt(process.argv[4]) * 1000; -const accessToken = process.argv[5]; -const wsUrl = process.argv[6]; - -console.log(`Connecting to: ${wsUrl}`); -console.log(`Duration: ${duration/1000}s`); - -// Parse WebSocket URL -const url = new URL(wsUrl); -const isSecure = url.protocol === 'wss:'; -const port = url.port || (isSecure ? 443 : 80); - -// Create WebSocket key (required for WebSocket handshake) -const key = crypto.randomBytes(16).toString('base64'); - -// WebSocket handshake headers -const headers = { - 'Upgrade': 'websocket', - 'Connection': 'Upgrade', - 'Sec-WebSocket-Key': key, - 'Sec-WebSocket-Version': '13', - 'Authorization': `Bearer ${accessToken}` -}; - -const options = { - hostname: url.hostname, - port: port, - path: url.pathname, - method: 'GET', - headers: headers -}; - -console.log(`Attempting WebSocket handshake to ${url.hostname}:${port}${url.pathname}`); - -const client = isSecure ? https : http; -let messageCount = 0; -let startTime = Date.now(); - -const req = client.request(options); - -req.on('upgrade', (res, socket, head) => { - console.log('āœ… WebSocket handshake successful'); - - let buffer = ''; - - socket.on('data', (data) => { - buffer += data.toString(); - - // Process complete WebSocket frames - while (buffer.length > 0) { - // Simple WebSocket frame parsing (for text frames) - if (buffer.length < 2) break; - - const firstByte = buffer.charCodeAt(0); - const secondByte = buffer.charCodeAt(1); - - const opcode = firstByte & 0x0F; - const masked = (secondByte & 0x80) === 0x80; - let payloadLength = secondByte & 0x7F; - - let offset = 2; - - if (payloadLength === 126) { - if (buffer.length < offset + 2) break; - payloadLength = (buffer.charCodeAt(offset) << 8) | buffer.charCodeAt(offset + 1); - offset += 2; - } else if (payloadLength === 127) { - if (buffer.length < offset + 8) break; - // For simplicity, assume payload length fits in 32 bits - payloadLength = (buffer.charCodeAt(offset + 4) << 24) | - (buffer.charCodeAt(offset + 5) << 16) | - (buffer.charCodeAt(offset + 6) << 8) | - buffer.charCodeAt(offset + 7); - offset += 8; - } - - if (buffer.length < offset + payloadLength) break; - - // Extract payload - let payload = buffer.slice(offset, offset + payloadLength); - buffer = buffer.slice(offset + payloadLength); - - if (opcode === 1) { // Text frame - messageCount++; - const timestamp = new Date().toLocaleTimeString(); - - try { - const message = JSON.parse(payload); - console.log(`\n[${timestamp}] šŸ“Ø Message ${messageCount}:`); - console.log(` Type: ${message.type || 'unknown'}`); - console.log(` Job ID: ${message.job_id || 'unknown'}`); - - if (message.data) { - if (message.data.progress !== undefined) { - console.log(` Progress: ${message.data.progress}%`); - } - if (message.data.current_step) { - console.log(` Step: ${message.data.current_step}`); - } - } - - if (message.type === 'completed') { - console.log('šŸŽ‰ Training completed!'); - } else if (message.type === 'failed') { - console.log('āŒ Training failed!'); - } - - } catch (e) { - console.log(`[${timestamp}] Raw message: ${payload}`); - } - } else if (opcode === 8) { // Close frame - console.log('šŸ”Œ WebSocket closed by server'); - socket.end(); - return; - } - } - }); - - socket.on('end', () => { - console.log(`\nšŸ“Š WebSocket test completed`); - console.log(`šŸ“Ø Total messages received: ${messageCount}`); - if (messageCount > 0) { - console.log('āœ… WebSocket communication successful'); - } else { - console.log('āš ļø No messages received during test period'); - } - process.exit(0); - }); - - socket.on('error', (error) => { - console.log('āŒ WebSocket error:', error.message); - process.exit(1); - }); - - // Send periodic pings to keep connection alive - const pingInterval = setInterval(() => { - if (socket.writable) { - // Send ping frame (opcode 9) - const pingFrame = Buffer.from([0x89, 0x00]); - socket.write(pingFrame); - } - }, 10000); - - // Close after duration - setTimeout(() => { - clearInterval(pingInterval); - console.log(`\nā° Test duration (${duration/1000}s) completed`); - console.log(`šŸ“Ø Total messages received: ${messageCount}`); - - if (messageCount > 0) { - console.log('āœ… WebSocket communication successful'); - } else { - console.log('āš ļø No training progress messages received'); - console.log(' This is normal if training completed before WebSocket connection'); - } - - socket.end(); - process.exit(0); - }, duration); -}); - -req.on('response', (res) => { - console.log(`āŒ HTTP response instead of WebSocket upgrade: ${res.statusCode}`); - console.log('Response headers:', res.headers); - - let body = ''; - res.on('data', chunk => body += chunk); - res.on('end', () => { - if (body) console.log('Response body:', body); - }); - - process.exit(1); -}); - -req.on('error', (error) => { - console.log('āŒ Connection error:', error.message); - process.exit(1); -}); - -req.end(); -EOF - - # Run the Node.js WebSocket test - local ws_url="$WS_BASE/tenants/$tenant_id/training/jobs/$job_id/live" - echo "Starting WebSocket test..." - node "$ws_test_script" "$tenant_id" "$job_id" "$duration" "$ACCESS_TOKEN" "$ws_url" - local exit_code=$? - - # Clean up - rm -f "$ws_test_script" - - return $exit_code -} - -# Fallback: Test WebSocket using curl (limited functionality) -test_websocket_with_curl() { - local tenant_id="$1" - local job_id="$2" - local duration="$3" - - log_warning "WebSocket testing tools not available (websocat/node.js)" - echo "Falling back to HTTP polling simulation..." - - # Create a simple HTTP-based progress polling simulation - local poll_endpoint="$API_BASE/api/v1/tenants/$tenant_id/training/jobs/$job_id/status" - local end_time=$(($(date +%s) + duration)) - local poll_count=0 - - echo "Simulating real-time updates by polling: $poll_endpoint" - echo "Duration: ${duration}s" - - while [ $(date +%s) -lt $end_time ]; do - poll_count=$((poll_count + 1)) - echo "" - echo "[$(date '+%H:%M:%S')] Poll #$poll_count - Checking training status..." - - STATUS_RESPONSE=$(curl -s -X GET "$poll_endpoint" \ - -H "Authorization: Bearer $ACCESS_TOKEN" \ - -H "X-Tenant-ID: $tenant_id") - - echo "Response:" - echo "$STATUS_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$STATUS_RESPONSE" - - # Check if training is complete - if echo "$STATUS_RESPONSE" | grep -q '"status".*"completed"\|"status".*"failed"'; then - log_success "Training status detected as complete/failed - stopping polling" - break - fi - - sleep 5 - done - - log_success "HTTP polling simulation completed ($poll_count polls)" - echo "šŸ’” For real WebSocket testing, install: npm install -g websocat" -} - # Wait for WebSocket messages and analyze them wait_for_websocket_messages() { local ws_log="$1" @@ -660,8 +693,8 @@ wait_for_websocket_messages() { # Enhanced training step with WebSocket testing enhanced_training_step_with_completion_check() { - echo -e "${STEP_ICONS[3]} ${PURPLE}STEP 4: MODEL TRAINING WITH WEBSOCKET MONITORING${NC}" - echo "Enhanced training step with real-time progress monitoring" + echo -e "${STEP_ICONS[3]} ${PURPLE}STEP 4: MODEL TRAINING WITH SMART WEBSOCKET MONITORING${NC}" + echo "Enhanced training step with completion-aware progress monitoring" echo "" log_step "4.1. Initiating model training with FULL dataset" @@ -694,42 +727,50 @@ enhanced_training_step_with_completion_check() { echo " Job ID: $WEBSOCKET_JOB_ID" echo " Status: $JOB_STATUS" - # Check if training completed instantly + # Determine monitoring strategy based on initial status if [ "$JOB_STATUS" = "completed" ]; then log_warning "Training completed instantly - no real-time progress to monitor" echo " This can happen when:" + echo " • Models are already trained and cached" echo " • No valid products found in sales data" echo " • Training data is insufficient" - echo " • Models are already trained and cached" - echo "" # Show training results TOTAL_PRODUCTS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.total_products") SUCCESSFUL_TRAININGS=$(extract_json_field "$TRAINING_RESPONSE" "training_results.successful_trainings") SALES_RECORDS=$(extract_json_field "$TRAINING_RESPONSE" "data_summary.sales_records") + echo "" echo "šŸ“Š Training Summary:" echo " Sales records: $SALES_RECORDS" echo " Products found: $TOTAL_PRODUCTS" echo " Successful trainings: $SUCCESSFUL_TRAININGS" - if [ "$TOTAL_PRODUCTS" = "0" ]; then - log_warning "No products found for training" - echo " Possible causes:" - echo " • CSV doesn't contain valid product names" - echo " • Product column is missing or malformed" - echo " • Insufficient sales data per product" - fi - - # Still test WebSocket for demonstration + # Brief WebSocket connection test log_step "4.2. Testing WebSocket endpoint (demonstration mode)" - echo "Even though training is complete, testing WebSocket connection..." + echo "Testing WebSocket connection for 10 seconds..." test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "10" else - # Training is in progress - monitor with WebSocket - log_step "4.2. Connecting to WebSocket for real-time progress monitoring" - test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "$WS_TEST_DURATION" + # Training is in progress - use smart monitoring + log_step "4.2. Starting smart WebSocket monitoring" + echo " Strategy: Monitor until job completion" + echo " Maximum wait time: ${WS_TEST_DURATION}s (safety timeout)" + echo " Will automatically close when training completes" + echo "" + + # Use enhanced monitoring with longer timeout for real training + local SMART_DURATION=$WS_TEST_DURATION + + # Estimate duration based on data size (optional enhancement) + if [ -n "$SALES_RECORDS" ] && [ "$SALES_RECORDS" -gt 1000 ]; then + # For large datasets, extend timeout + SMART_DURATION=$((WS_TEST_DURATION * 2)) + echo " šŸ“Š Large dataset detected ($SALES_RECORDS records)" + echo " šŸ• Extended timeout to ${SMART_DURATION}s for thorough training" + fi + + test_websocket_with_nodejs_builtin "$TENANT_ID" "$WEBSOCKET_JOB_ID" "$SMART_DURATION" fi else @@ -786,41 +827,6 @@ services_check() { done } -check_websocket_prerequisites() { - echo -e "${PURPLE}šŸ” Checking WebSocket testing prerequisites...${NC}" - - # Check for websocat - if command -v websocat >/dev/null 2>&1; then - log_success "websocat found - will use for WebSocket testing" - return 0 - fi - - # Check for Node.js - if command -v node >/dev/null 2>&1; then - local node_version=$(node --version 2>/dev/null || echo "unknown") - log_success "Node.js found ($node_version) - will use for WebSocket testing" - - # Check if ws module is available (try to require it) - if node -e "require('ws')" 2>/dev/null; then - log_success "Node.js 'ws' module available" - else - log_warning "Node.js 'ws' module not found" - echo " Install with: npm install -g ws" - echo " Will attempt to use built-in functionality..." - fi - return 0 - fi - - log_warning "Neither websocat nor Node.js found" - echo " WebSocket testing will use HTTP polling fallback" - echo " For better testing, install one of:" - echo " • websocat: cargo install websocat" - echo " • Node.js: https://nodejs.org/" - - return 1 -} - - services_check echo "" @@ -1167,7 +1173,7 @@ echo "" # STEP 4: MODEL TRAINING (ONBOARDING PAGE STEP 4) # ================================================================= -check_websocket_prerequisites +test_websocket_connection enhanced_training_step_with_completion_check