Merge pull request #2 from ualsweb/claude/debug-onboarding-training-step-011CUpmWixCPTKKW2re8qJ3A
Fix multiple critical bugs in onboarding training step
@@ -8,6 +8,17 @@ import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOpti
 import { trainingService } from '../services/training';
 import { ApiError, apiClient } from '../client/apiClient';
 import { useAuthStore } from '../../stores/auth.store';
+import {
+  HTTP_POLLING_INTERVAL_MS,
+  HTTP_POLLING_DEBOUNCE_MS,
+  WEBSOCKET_HEARTBEAT_INTERVAL_MS,
+  WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
+  WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
+  WEBSOCKET_RECONNECT_MAX_DELAY_MS,
+  PROGRESS_DATA_ANALYSIS,
+  PROGRESS_TRAINING_RANGE_START,
+  PROGRESS_TRAINING_RANGE_END
+} from '../../constants/training';
 import type {
   TrainingJobRequest,
   TrainingJobResponse,
@@ -53,14 +64,32 @@ export const useTrainingJobStatus = (
   }
 ) => {
   const { isWebSocketConnected, ...queryOptions } = options || {};
+  const [enablePolling, setEnablePolling] = React.useState(false);
 
-  // Completely disable the query when WebSocket is connected
-  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected;
+  // Debounce HTTP polling activation: wait after WebSocket disconnects
+  // This prevents race conditions where both WebSocket and HTTP are briefly active
+  React.useEffect(() => {
+    if (!isWebSocketConnected) {
+      const debounceTimer = setTimeout(() => {
+        setEnablePolling(true);
+        console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
+      }, HTTP_POLLING_DEBOUNCE_MS);
+
+      return () => clearTimeout(debounceTimer);
+    } else {
+      setEnablePolling(false);
+      console.log('❌ HTTP polling disabled (WebSocket connected)');
+    }
+  }, [isWebSocketConnected]);
+
+  // Completely disable the query when WebSocket is connected or during debounce period
+  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling;
 
   console.log('🔄 Training status query:', {
     tenantId: !!tenantId,
     jobId: !!jobId,
     isWebSocketConnected,
+    enablePolling,
     queryEnabled: isEnabled
   });
 
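Note on the debounce above: the pattern is easier to review in isolation. A minimal sketch, outside React, with illustrative names (`setPolling` stands in for `setEnablePolling`):

```typescript
// Sketch: debounced polling flag, mirroring the effect above.
let pollingTimer: ReturnType<typeof setTimeout> | null = null;

function onWebSocketStateChange(connected: boolean, setPolling: (v: boolean) => void) {
  if (pollingTimer) { clearTimeout(pollingTimer); pollingTimer = null; }
  if (connected) {
    setPolling(false); // WebSocket live: stop HTTP polling immediately
  } else {
    // Wait before enabling polling so a brief reconnect never leaves
    // both WebSocket and HTTP polling active at the same time.
    pollingTimer = setTimeout(() => setPolling(true), 5000);
  }
}
```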
@@ -85,8 +114,8 @@ export const useTrainingJobStatus = (
         return false; // Stop polling when training is done
       }
 
-      console.log('📊 HTTP fallback polling active (WebSocket disconnected) - 5s interval');
-      return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable)
+      console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
+      return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable)
     } : false, // Completely disable interval when WebSocket connected
     staleTime: 1000, // Consider data stale after 1 second
     retry: (failureCount, error) => {
@@ -298,7 +327,7 @@ export const useTrainingWebSocket = (
     let reconnectTimer: NodeJS.Timeout | null = null;
     let isManuallyDisconnected = false;
     let reconnectAttempts = 0;
-    const maxReconnectAttempts = 3;
+    const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;
 
     const connect = async () => {
       try {
@@ -349,70 +378,49 @@ export const useTrainingWebSocket = (
          console.warn('Failed to request status on connection:', e);
        }
 
-      // Helper function to check if tokens represent different auth sessions
-      const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => {
+      // Helper function to check if tokens represent different auth users/sessions
+      const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
        if (!oldToken || !newToken) return !!oldToken !== !!newToken;
 
        try {
          const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
          const newPayload = JSON.parse(atob(newToken.split('.')[1]));
 
-          // Compare by issued timestamp (iat) - different iat means new auth session
-          return oldPayload.iat !== newPayload.iat;
+          // Compare by user ID - different user means new auth session
+          // If user_id is same, it's just a token refresh, no need to reconnect
+          return oldPayload.user_id !== newPayload.user_id ||
+                 oldPayload.sub !== newPayload.sub;
        } catch (e) {
-          console.warn('Failed to parse token for session comparison, falling back to string comparison:', e);
-          return oldToken !== newToken;
+          console.warn('Failed to parse token for session comparison:', e);
+          // On parse error, don't reconnect (assume same session)
+          return false;
        }
      };
 
-      // Set up periodic ping and intelligent token refresh detection
+      // Set up periodic ping and check for auth session changes
      const heartbeatInterval = setInterval(async () => {
        if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
          try {
            // Check token validity (this may refresh if needed)
            const currentToken = await apiClient.ensureValidToken();
 
-            // Enhanced token change detection with detailed logging
-            const tokenStringChanged = currentToken !== effectiveToken;
-            const tokenSessionChanged = currentToken && effectiveToken ?
-              isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged;
-
-            console.log('🔍 WebSocket token validation check:', {
-              hasCurrentToken: !!currentToken,
-              hasEffectiveToken: !!effectiveToken,
-              tokenStringChanged,
-              tokenSessionChanged,
-              currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null',
-              effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null'
-            });
-
-            // Only reconnect if we have a genuine session change (different iat)
-            if (tokenSessionChanged) {
-              console.log('🔄 Token session changed - reconnecting WebSocket with new session token');
-              console.log('📊 Session change details:', {
-                reason: !currentToken ? 'token removed' :
-                        !effectiveToken ? 'token added' : 'new auth session',
-                oldTokenIat: effectiveToken ? (() => {
-                  try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-                })() : 'N/A',
-                newTokenIat: currentToken ? (() => {
-                  try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-                })() : 'N/A'
-              });
-
-              // Close current connection and trigger reconnection with new token
-              ws?.close(1000, 'Token session changed - reconnecting');
+            // Only reconnect if user changed (new auth session)
+            if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) {
+              console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
+              ws?.close(1000, 'Auth session changed - reconnecting');
              clearInterval(heartbeatInterval);
              return;
-            } else if (tokenStringChanged) {
-              console.log('ℹ️ Token string changed but same session - continuing with current connection');
-              // Update effective token reference for future comparisons
            }
 
-            console.log('✅ Token validated during heartbeat - same session');
+            // Token may have been refreshed but it's the same user - continue
+            if (currentToken && currentToken !== effectiveToken) {
+              console.log('ℹ️ Token refreshed (same user) - updating reference');
+              effectiveToken = currentToken;
+            }
+
            // Send ping
            ws?.send('ping');
-            console.log('💓 Sent ping to server (token session validated)');
+            console.log('💓 Sent ping to server');
          } catch (e) {
            console.warn('Failed to send ping or validate token:', e);
            clearInterval(heartbeatInterval);
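The session check above decodes the JWT payload client-side. A self-contained sketch of the same idea (assumes a standard three-part JWT; the `user_id`/`sub` claim names follow the code above; no signature verification is implied, and base64url padding edge cases are ignored):

```typescript
// Decode a JWT payload without verifying the signature (client-side only).
function decodeJwtPayload(token: string): Record<string, unknown> | null {
  try {
    const base64 = token.split('.')[1].replace(/-/g, '+').replace(/_/g, '/');
    return JSON.parse(atob(base64));
  } catch {
    return null;
  }
}

// Mirrors isNewAuthSession: only a changed identity (not a routine
// token refresh) should force a WebSocket reconnect.
function isSameUser(oldToken: string, newToken: string): boolean {
  const a = decodeJwtPayload(oldToken);
  const b = decodeJwtPayload(newToken);
  if (!a || !b) return true; // parse error: assume same session, don't reconnect
  return a.user_id === b.user_id && a.sub === b.sub;
}
```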
@@ -420,7 +428,7 @@ export const useTrainingWebSocket = (
        } else {
          clearInterval(heartbeatInterval);
        }
-      }, 30000); // Check every 30 seconds for token refresh and send ping
+      }, WEBSOCKET_HEARTBEAT_INTERVAL_MS); // Check for auth changes and send ping
 
      // Store interval for cleanup
      (ws as any).heartbeatInterval = heartbeatInterval;
@@ -449,7 +457,8 @@ export const useTrainingWebSocket = (
            if (initialData.type === 'product_completed') {
              const productsCompleted = initialEventData.products_completed || 0;
              const totalProducts = initialEventData.total_products || 1;
-              initialProgress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+              const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+              initialProgress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
              console.log('📦 Product training completed in initial state',
                `${productsCompleted}/${totalProducts}`,
                `progress: ${initialProgress}%`);
@@ -486,8 +495,9 @@ export const useTrainingWebSocket = (
          const productsCompleted = eventData.products_completed || 0;
          const totalProducts = eventData.total_products || 1;
 
-          // Calculate progress: 20% base + (completed/total * 60%)
-          progress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+          // Calculate progress: DATA_ANALYSIS% base + (completed/total * (TRAINING_RANGE_END - DATA_ANALYSIS)%)
+          const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+          progress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
 
          console.log('📦 Product training completed',
            `${productsCompleted}/${totalProducts}`,
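For reference, the progress math above maps per-product completion into the 20-80% band defined by the constants file added in this PR. A worked check:

```typescript
// progress = PROGRESS_DATA_ANALYSIS + floor(completed / total * width)
// width = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS = 80 - 20 = 60
// e.g. 3 of 12 products done: 20 + floor(3 / 12 * 60) = 20 + 15 = 35%
function trainingProgress(completed: number, total: number): number {
  const width = 80 - 20; // PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS
  return 20 + Math.floor((completed / Math.max(total, 1)) * width);
}
```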
@@ -585,8 +595,8 @@ export const useTrainingWebSocket = (
        // Detailed logging for different close codes
        switch (event.code) {
          case 1000:
-            if (event.reason === 'Token refreshed - reconnecting') {
-              console.log('🔄 WebSocket closed for token refresh - will reconnect immediately');
+            if (event.reason === 'Auth session changed - reconnecting') {
+              console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
            } else {
              console.log('🔒 WebSocket closed normally');
            }
@@ -604,18 +614,21 @@ export const useTrainingWebSocket = (
            console.log(`❓ WebSocket closed with code ${event.code}`);
        }
 
-        // Handle token refresh reconnection (immediate reconnect)
-        if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') {
-          console.log('🔄 Reconnecting immediately due to token refresh...');
+        // Handle auth session change reconnection (immediate reconnect)
+        if (event.code === 1000 && event.reason === 'Auth session changed - reconnecting') {
+          console.log('🔄 Reconnecting immediately due to auth session change...');
          reconnectTimer = setTimeout(() => {
-            connect(); // Reconnect immediately with fresh token
-          }, 1000); // Short delay to allow cleanup
+            connect(); // Reconnect immediately with new session token
+          }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
          return;
        }
 
        // Try to reconnect if not manually disconnected and haven't exceeded max attempts
        if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
-          const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s
+          const delay = Math.min(
+            WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
+            WEBSOCKET_RECONNECT_MAX_DELAY_MS
+          ); // Exponential backoff
          console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
 
          reconnectTimer = setTimeout(() => {
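The reconnect delay now derives from the shared constants; a quick sketch of the resulting schedule (same formula as above, not new behavior):

```typescript
// delay = min(INITIAL * 2^attempt, MAX) -> 1s, 2s, 4s for attempts 0..2,
// capped at 10s should WEBSOCKET_MAX_RECONNECT_ATTEMPTS ever be raised.
function reconnectDelayMs(attempt: number, initialMs = 1000, maxMs = 10000): number {
  return Math.min(initialMs * Math.pow(2, attempt), maxMs);
}
```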
@@ -5,6 +5,11 @@ import { Button } from '../../../ui/Button';
 import { useCurrentTenant } from '../../../../stores/tenant.store';
 import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training';
 import { Info } from 'lucide-react';
+import {
+  TRAINING_SKIP_OPTION_DELAY_MS,
+  TRAINING_COMPLETION_DELAY_MS,
+  SKIP_TIMER_CHECK_INTERVAL_MS
+} from '../../../../constants/training';
 
 interface MLTrainingStepProps {
   onNext: () => void;
@@ -38,16 +43,16 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
   const currentTenant = useCurrentTenant();
   const createTrainingJob = useCreateTrainingJob();
 
-  // Check if training has been running for more than 2 minutes
+  // Check if training has been running for more than the skip delay threshold
   useEffect(() => {
     if (trainingStartTime && isTraining && !showSkipOption) {
       const checkTimer = setInterval(() => {
         const elapsedTime = (Date.now() - trainingStartTime) / 1000; // in seconds
-        if (elapsedTime > 120) { // 2 minutes
+        if (elapsedTime > TRAINING_SKIP_OPTION_DELAY_MS / 1000) {
          setShowSkipOption(true);
          clearInterval(checkTimer);
        }
-      }, 5000); // Check every 5 seconds
+      }, SKIP_TIMER_CHECK_INTERVAL_MS);
 
      return () => clearInterval(checkTimer);
    }
@@ -72,14 +77,14 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
        message: 'Entrenamiento completado exitosamente'
      });
      setIsTraining(false);
 
 
      setTimeout(() => {
        onComplete({
          jobId: jobId,
          success: true,
          message: 'Modelo entrenado correctamente'
        });
-      }, 2000);
+      }, TRAINING_COMPLETION_DELAY_MS);
  }, [onComplete, jobId]);
 
  const handleError = useCallback((data: any) => {
@@ -147,7 +152,7 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
          message: 'Modelo entrenado correctamente',
          detectedViaPolling: true
        });
-      }, 2000);
+      }, TRAINING_COMPLETION_DELAY_MS);
    } else if (jobStatus.status === 'failed') {
      console.log(`❌ Training failure detected (source: ${isConnected ? 'WebSocket' : 'HTTP polling'})`);
      setError('Error detectado durante el entrenamiento (verificación de estado)');
@@ -169,13 +174,15 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
    }
  }, [jobStatus, jobId, trainingProgress?.stage, onComplete, isConnected]);
 
-  // Auto-trigger training when component mounts
+  // Auto-trigger training when component mounts (run once)
+  const hasAutoStarted = React.useRef(false);
  useEffect(() => {
-    if (currentTenant?.id && !isTraining && !trainingProgress && !error) {
+    if (currentTenant?.id && !hasAutoStarted.current && !isTraining && !trainingProgress && !error) {
      console.log('🚀 Auto-starting ML training for tenant:', currentTenant.id);
+      hasAutoStarted.current = true;
      handleStartTraining();
    }
-  }, [currentTenant?.id]); // Only run when tenant is available
+  }, [currentTenant?.id, isTraining, trainingProgress, error]); // Include all checked dependencies
 
  const handleStartTraining = async () => {
    if (!currentTenant?.id) {
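The ref guard above is the standard way to keep an auto-start effect from re-firing once its dependency list widens. Reduced to a reusable skeleton (illustrative only; `start` stands in for `handleStartTraining`):

```typescript
import React from 'react';

// Sketch: run a side effect's action at most once, even as deps change.
function useRunOnce(ready: boolean, start: () => void) {
  const hasRun = React.useRef(false);
  React.useEffect(() => {
    if (ready && !hasRun.current) {
      hasRun.current = true; // flip before starting to avoid a double-fire
      start();
    }
  }, [ready, start]);
}
```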
frontend/src/constants/training.ts (new file)
@@ -0,0 +1,25 @@
+/**
+ * Training Progress Constants
+ * Centralized constants for training UI and behavior
+ */
+
+// Time Intervals (milliseconds)
+export const TRAINING_SKIP_OPTION_DELAY_MS = 120000; // 2 minutes
+export const TRAINING_COMPLETION_DELAY_MS = 2000; // 2 seconds
+export const HTTP_POLLING_INTERVAL_MS = 5000; // 5 seconds
+export const HTTP_POLLING_DEBOUNCE_MS = 5000; // 5 seconds
+export const WEBSOCKET_HEARTBEAT_INTERVAL_MS = 30000; // 30 seconds
+
+// WebSocket Configuration
+export const WEBSOCKET_MAX_RECONNECT_ATTEMPTS = 3;
+export const WEBSOCKET_RECONNECT_INITIAL_DELAY_MS = 1000; // 1 second
+export const WEBSOCKET_RECONNECT_MAX_DELAY_MS = 10000; // 10 seconds
+
+// Progress Milestones
+export const PROGRESS_DATA_ANALYSIS = 20;
+export const PROGRESS_TRAINING_RANGE_START = 20;
+export const PROGRESS_TRAINING_RANGE_END = 80;
+export const PROGRESS_COMPLETED = 100;
+
+// Skip Timer Check Interval
+export const SKIP_TIMER_CHECK_INTERVAL_MS = 5000; // 5 seconds
@@ -189,15 +189,8 @@ async def start_training_job(
     # Calculate estimated completion time
     estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
 
-    # Publish training.started event immediately so WebSocket clients
-    # have initial state when they connect
-    await publish_training_started(
-        job_id=job_id,
-        tenant_id=tenant_id,
-        total_products=0,  # Will be updated when actual training starts
-        estimated_duration_minutes=estimated_duration_minutes,
-        estimated_completion_time=estimated_completion_time.isoformat()
-    )
+    # Note: training.started event will be published by the trainer with accurate product count
+    # We don't publish here to avoid duplicate events
 
     # Add enhanced background task
     background_tasks.add_task(
@@ -401,11 +394,6 @@ async def execute_training_job_background(
-            # Failure event is published by the training service
-            await publish_training_failed(job_id, tenant_id, str(training_error))
 
     except Exception as background_error:
         logger.error("Critical error in enhanced background training job",
                      job_id=job_id,
                      error=str(background_error))
 
     finally:
         logger.info("Enhanced background training job cleanup completed",
                     job_id=job_id)
services/training/app/core/training_constants.py (new file)
@@ -0,0 +1,35 @@
+"""
+Training Progress Constants
+Centralized constants for training progress tracking and timing
+"""
+
+# Progress Milestones (percentage)
+PROGRESS_STARTED = 0
+PROGRESS_DATA_VALIDATION = 10
+PROGRESS_DATA_ANALYSIS = 20
+PROGRESS_DATA_PREPARATION_COMPLETE = 30
+PROGRESS_ML_TRAINING_START = 40
+PROGRESS_TRAINING_COMPLETE = 85
+PROGRESS_STORING_MODELS = 92
+PROGRESS_STORING_METRICS = 94
+PROGRESS_COMPLETED = 100
+
+# Progress Ranges
+PROGRESS_TRAINING_RANGE_START = 20  # After data analysis
+PROGRESS_TRAINING_RANGE_END = 80  # Before finalization
+PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START  # 60%
+
+# Time Limits and Intervals (seconds)
+MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800  # 30 minutes
+WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
+WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
+WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
+WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
+
+# Training Timeouts (seconds)
+TRAINING_SKIP_OPTION_DELAY_SECONDS = 120  # 2 minutes
+HTTP_POLLING_INTERVAL_MS = 5000  # 5 seconds
+HTTP_POLLING_DEBOUNCE_MS = 5000  # 5 seconds before enabling after WebSocket disconnect
+
+# Frontend Display
+TRAINING_COMPLETION_DELAY_MS = 2000  # Delay before navigating after completion
@@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend
 from typing import Dict, List, Any, Optional
 import pandas as pd
 import numpy as np
-from datetime import datetime
+from datetime import datetime, timezone
 import structlog
 import uuid
 import time
@@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer:
 
             # Event 2: Data Analysis (20%)
             # Recalculate time remaining based on elapsed time
-            elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
+            start_time = await repos['training_log'].get_start_time(job_id)
+            elapsed_seconds = 0
+            if start_time:
+                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
 
             # Estimate remaining time: we've done ~20% of work (data analysis)
             # Remaining 80% includes training all products
@@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer:
         except Exception as e:
             logger.error("Enhanced ML training pipeline failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
 
             # Publish training failed event
             await publish_training_failed(job_id, tenant_id, str(e))
@@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer:
             logger.error("Single product model training failed",
                          job_id=job_id,
                          inventory_product_id=inventory_product_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
             raise
 
     def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
@@ -329,4 +329,17 @@ class TrainingLogRepository(TrainingBaseRepository):
                 "min_duration_minutes": 0.0,
                 "max_duration_minutes": 0.0,
                 "completed_jobs_with_duration": 0
             }
         }
+
+    async def get_start_time(self, job_id: str) -> Optional[datetime]:
+        """Get the start time for a training job"""
+        try:
+            log_entry = await self.get_by_job_id(job_id)
+            if log_entry and log_entry.start_time:
+                return log_entry.start_time
+            return None
+        except Exception as e:
+            logger.error("Failed to get start time",
+                         job_id=job_id,
+                         error=str(e))
+            return None
@@ -10,6 +10,11 @@ from datetime import datetime, timezone
 
 from app.services.training_events import publish_product_training_completed
 from app.utils.time_estimation import calculate_estimated_completion_time
+from app.core.training_constants import (
+    PROGRESS_TRAINING_RANGE_START,
+    PROGRESS_TRAINING_RANGE_END,
+    PROGRESS_TRAINING_RANGE_WIDTH
+)
 
 logger = structlog.get_logger()
 
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
         self.start_time = datetime.now(timezone.utc)
 
         # Calculate progress increment per product
-        # 60% of total progress (from 20% to 80%) divided by number of products
-        self.progress_per_product = 60 / total_products if total_products > 0 else 0
+        # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
+        self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0
 
         logger.info("ParallelProductProgressTracker initialized",
                     job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
             estimated_completion_time=estimated_completion_time
         )
 
-        # Calculate overall progress (20% base + progress from completed products)
+        # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
         # This calculation is done on the frontend/consumer side based on the event data
-        overall_progress = 20 + int((current_progress / self.total_products) * 60)
+        overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
 
         logger.info("Product training completed",
                     job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
         return {
             "products_completed": self.products_completed,
             "total_products": self.total_products,
-            "progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
+            "progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
         }
@@ -135,6 +135,61 @@ async def publish_data_analysis(
     return success
 
 
+async def publish_training_progress(
+    job_id: str,
+    tenant_id: str,
+    progress: int,
+    current_step: str,
+    step_details: Optional[str] = None,
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
+) -> bool:
+    """
+    Generic Training Progress Event (for any progress percentage)
+
+    Args:
+        job_id: Training job identifier
+        tenant_id: Tenant identifier
+        progress: Progress percentage (0-100)
+        current_step: Current step name
+        step_details: Details about the current step
+        estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
+    """
+    event_data = {
+        "service_name": "training-service",
+        "event_type": "training.progress",
+        "timestamp": datetime.now().isoformat(),
+        "data": {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "progress": progress,
+            "current_step": current_step,
+            "step_details": step_details or current_step,
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
+        }
+    }
+
+    success = await training_publisher.publish_event(
+        exchange_name="training.events",
+        routing_key="training.progress",
+        event_data=event_data
+    )
+
+    if success:
+        logger.info("Published training progress event",
+                    job_id=job_id,
+                    progress=progress,
+                    current_step=current_step)
+    else:
+        logger.error("Failed to publish training progress event",
+                     job_id=job_id,
+                     progress=progress)
+
+    return success
+
+
 async def publish_product_training_completed(
     job_id: str,
     tenant_id: str,
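For frontend consumers of the new training.progress event, the payload published above corresponds to this shape (an assumed TypeScript mirror; field names are taken verbatim from event_data above):

```typescript
// Assumed mirror of the training.progress event payload published above.
interface TrainingProgressEvent {
  service_name: 'training-service';
  event_type: 'training.progress';
  timestamp: string; // ISO timestamp
  data: {
    job_id: string;
    tenant_id: string;
    progress: number; // 0-100
    current_step: string;
    step_details: string;
    estimated_time_remaining_seconds: number | null;
    estimated_completion_time: string | null; // ISO timestamp
  };
}
```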
@@ -16,6 +16,15 @@ import pandas as pd
 from app.ml.trainer import EnhancedBakeryMLTrainer
 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
 from app.services.training_orchestrator import TrainingDataOrchestrator
+from app.core.training_constants import (
+    PROGRESS_DATA_VALIDATION,
+    PROGRESS_DATA_PREPARATION_COMPLETE,
+    PROGRESS_ML_TRAINING_START,
+    PROGRESS_TRAINING_COMPLETE,
+    PROGRESS_STORING_MODELS,
+    PROGRESS_STORING_METRICS,
+    MAX_ESTIMATED_TIME_REMAINING_SECONDS
+)
 
 # Import repositories
 from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
             # Step 1: Prepare training dataset (includes sales data validation)
             logger.info("Step 1: Preparing and aligning training data (with validation)")
             await self.training_log_repo.update_log_progress(
-                job_id, 10, "data_validation", "running"
+                job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
             )
 
             # Orchestrator now handles sales data validation to eliminate duplicate fetching
@@ -204,13 +213,13 @@ class EnhancedTrainingService:
                 tenant_id=tenant_id, job_id=job_id)
 
             await self.training_log_repo.update_log_progress(
-                job_id, 30, "data_preparation_complete", "running"
+                job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
             )
 
 
             # Step 2: Execute ML training pipeline
             logger.info("Step 2: Starting ML training pipeline")
             await self.training_log_repo.update_log_progress(
-                job_id, 40, "ml_training", "running"
+                job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
             )
 
             training_results = await self.trainer.train_tenant_models(
@@ -220,9 +229,19 @@ class EnhancedTrainingService:
             )
 
             await self.training_log_repo.update_log_progress(
-                job_id, 85, "training_complete", "running"
+                job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
             )
 
+            # Publish progress event (85%)
+            from app.services.training_events import publish_training_progress
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_TRAINING_COMPLETE,
+                current_step="Training Complete",
+                step_details="All products trained successfully"
+            )
+
             # Step 3: Store model records using repository
             logger.info("Step 3: Storing model records")
             logger.debug("Training results structure",
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
             )
 
             await self.training_log_repo.update_log_progress(
-                job_id, 92, "storing_models", "running"
+                job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
             )
 
+            # Publish progress event (92%)
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_STORING_MODELS,
+                current_step="Storing Models",
+                step_details=f"Saved {len(stored_models)} trained models to database"
+            )
+
             # Step 4: Create performance metrics
             await self.training_log_repo.update_log_progress(
-                job_id, 94, "storing_performance_metrics", "running"
+                job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
             )
 
+            # Publish progress event (94%)
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_STORING_METRICS,
+                current_step="Storing Performance Metrics",
+                step_details="Saving model performance metrics"
+            )
+
             await self._create_performance_metrics(
@@ -316,8 +353,9 @@ class EnhancedTrainingService:
         except Exception as e:
             logger.error("Enhanced training job failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
 
             # Mark as failed in database
             await self.training_log_repo.complete_training_log(
                 job_id, error_message=str(e)
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
         if log.status == "running" and log.progress > 0 and log.start_time:
             from datetime import datetime, timezone
             elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
-            if elapsed_time > 0:
+            if elapsed_time > 0 and log.progress > 0:  # Double-check progress is positive
                 # Calculate estimated total time based on progress
-                estimated_total_time = (elapsed_time / log.progress) * 100
+                # Use max(log.progress, 1) as additional safety against division by zero
+                estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
                 estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
-                # Cap at reasonable maximum (e.g., 30 minutes)
-                estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
+                # Cap at reasonable maximum (30 minutes)
+                estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
 
         # Extract products info from results if available
         products_total = 0
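The ETA logic above extrapolates linearly from progress and clamps the result. A worked sketch of the same formula (TypeScript for brevity; mirrors the Python above):

```typescript
// total ≈ elapsed / progress * 100; remaining = total - elapsed,
// clamped to [0, 1800] seconds (MAX_ESTIMATED_TIME_REMAINING_SECONDS).
// e.g. 120s elapsed at 40% progress: total ≈ 300s, remaining ≈ 180s.
function estimatedRemainingSeconds(elapsedSeconds: number, progressPct: number): number {
  const estimatedTotal = (elapsedSeconds / Math.max(progressPct, 1)) * 100;
  return Math.max(0, Math.min(Math.round(estimatedTotal - elapsedSeconds), 1800));
}
```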