Merge pull request #2 from ualsweb/claude/debug-onboarding-training-step-011CUpmWixCPTKKW2re8qJ3A

Fix multiple critical bugs in onboarding training step
ualsweb committed 2025-11-05 14:05:40 +01:00 (committed by GitHub)
10 changed files with 291 additions and 106 deletions

View File

@@ -8,6 +8,17 @@ import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOpti
 import { trainingService } from '../services/training';
 import { ApiError, apiClient } from '../client/apiClient';
 import { useAuthStore } from '../../stores/auth.store';
+import {
+  HTTP_POLLING_INTERVAL_MS,
+  HTTP_POLLING_DEBOUNCE_MS,
+  WEBSOCKET_HEARTBEAT_INTERVAL_MS,
+  WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
+  WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
+  WEBSOCKET_RECONNECT_MAX_DELAY_MS,
+  PROGRESS_DATA_ANALYSIS,
+  PROGRESS_TRAINING_RANGE_START,
+  PROGRESS_TRAINING_RANGE_END
+} from '../../constants/training';
 import type {
   TrainingJobRequest,
   TrainingJobResponse,
@@ -53,14 +64,32 @@ export const useTrainingJobStatus = (
   }
 ) => {
   const { isWebSocketConnected, ...queryOptions } = options || {};
+  const [enablePolling, setEnablePolling] = React.useState(false);

-  // Completely disable the query when WebSocket is connected
-  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected;
+  // Debounce HTTP polling activation: wait after WebSocket disconnects
+  // This prevents race conditions where both WebSocket and HTTP are briefly active
+  React.useEffect(() => {
+    if (!isWebSocketConnected) {
+      const debounceTimer = setTimeout(() => {
+        setEnablePolling(true);
+        console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
+      }, HTTP_POLLING_DEBOUNCE_MS);
+      return () => clearTimeout(debounceTimer);
+    } else {
+      setEnablePolling(false);
+      console.log('❌ HTTP polling disabled (WebSocket connected)');
+    }
+  }, [isWebSocketConnected]);
+
+  // Completely disable the query when WebSocket is connected or during debounce period
+  const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling;

   console.log('🔄 Training status query:', {
     tenantId: !!tenantId,
     jobId: !!jobId,
     isWebSocketConnected,
+    enablePolling,
     queryEnabled: isEnabled
   });
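The debounce above is the heart of the race-condition fix: HTTP polling only switches on if the WebSocket stays down for the whole debounce window. A minimal framework-free sketch of the same pattern, assuming only standard timers (the function and parameter names are illustrative, not part of this PR):

// Debounced fallback activation: the fallback turns on only if the primary
// channel stays disconnected for the full debounce window.
function makeFallbackSwitch(
  setFallbackEnabled: (enabled: boolean) => void,
  debounceMs: number
): (primaryConnected: boolean) => void {
  let timer: ReturnType<typeof setTimeout> | null = null;
  return (primaryConnected: boolean) => {
    if (timer !== null) { clearTimeout(timer); timer = null; }
    if (primaryConnected) {
      setFallbackEnabled(false); // primary is back: disable fallback immediately
    } else {
      timer = setTimeout(() => setFallbackEnabled(true), debounceMs); // ride out brief drops
    }
  };
}

Each reconnect cancels any pending activation, so a brief WebSocket blip never enables polling at all.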
@@ -85,8 +114,8 @@ export const useTrainingJobStatus = (
           return false; // Stop polling when training is done
         }
-        console.log('📊 HTTP fallback polling active (WebSocket disconnected) - 5s interval');
-        return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable)
+        console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
+        return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable)
       } : false, // Completely disable interval when WebSocket connected
     staleTime: 1000, // Consider data stale after 1 second
     retry: (failureCount, error) => {
@@ -298,7 +327,7 @@ export const useTrainingWebSocket = (
   let reconnectTimer: NodeJS.Timeout | null = null;
   let isManuallyDisconnected = false;
   let reconnectAttempts = 0;
-  const maxReconnectAttempts = 3;
+  const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;

   const connect = async () => {
     try {
@@ -349,70 +378,49 @@ export const useTrainingWebSocket = (
         console.warn('Failed to request status on connection:', e);
       }

-      // Helper function to check if tokens represent different auth sessions
-      const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => {
+      // Helper function to check if tokens represent different auth users/sessions
+      const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
         if (!oldToken || !newToken) return !!oldToken !== !!newToken;
         try {
           const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
           const newPayload = JSON.parse(atob(newToken.split('.')[1]));
-          // Compare by issued timestamp (iat) - different iat means new auth session
-          return oldPayload.iat !== newPayload.iat;
+          // Compare by user ID - different user means new auth session
+          // If user_id is same, it's just a token refresh, no need to reconnect
+          return oldPayload.user_id !== newPayload.user_id ||
+                 oldPayload.sub !== newPayload.sub;
         } catch (e) {
-          console.warn('Failed to parse token for session comparison, falling back to string comparison:', e);
-          return oldToken !== newToken;
+          console.warn('Failed to parse token for session comparison:', e);
+          // On parse error, don't reconnect (assume same session)
+          return false;
         }
       };

-      // Set up periodic ping and intelligent token refresh detection
+      // Set up periodic ping and check for auth session changes
       const heartbeatInterval = setInterval(async () => {
         if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
           try {
             // Check token validity (this may refresh if needed)
             const currentToken = await apiClient.ensureValidToken();

-            // Enhanced token change detection with detailed logging
-            const tokenStringChanged = currentToken !== effectiveToken;
-            const tokenSessionChanged = currentToken && effectiveToken ?
-              isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged;
-            console.log('🔍 WebSocket token validation check:', {
-              hasCurrentToken: !!currentToken,
-              hasEffectiveToken: !!effectiveToken,
-              tokenStringChanged,
-              tokenSessionChanged,
-              currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null',
-              effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null'
-            });
-
-            // Only reconnect if we have a genuine session change (different iat)
-            if (tokenSessionChanged) {
-              console.log('🔄 Token session changed - reconnecting WebSocket with new session token');
-              console.log('📊 Session change details:', {
-                reason: !currentToken ? 'token removed' :
-                  !effectiveToken ? 'token added' : 'new auth session',
-                oldTokenIat: effectiveToken ? (() => {
-                  try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-                })() : 'N/A',
-                newTokenIat: currentToken ? (() => {
-                  try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-                })() : 'N/A'
-              });
-              // Close current connection and trigger reconnection with new token
-              ws?.close(1000, 'Token session changed - reconnecting');
+            // Only reconnect if user changed (new auth session)
+            if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) {
+              console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
+              ws?.close(1000, 'Auth session changed - reconnecting');
               clearInterval(heartbeatInterval);
               return;
-            } else if (tokenStringChanged) {
-              console.log(' Token string changed but same session - continuing with current connection');
-              // Update effective token reference for future comparisons
+            }
+
+            // Token may have been refreshed but it's the same user - continue
+            if (currentToken && currentToken !== effectiveToken) {
+              console.log(' Token refreshed (same user) - updating reference');
               effectiveToken = currentToken;
             }
-            console.log('✅ Token validated during heartbeat - same session');
+
+            // Send ping
             ws?.send('ping');
-            console.log('💓 Sent ping to server (token session validated)');
+            console.log('💓 Sent ping to server');
           } catch (e) {
             console.warn('Failed to send ping or validate token:', e);
             clearInterval(heartbeatInterval);
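The new isNewAuthSession helper hinges on decoding the JWT payload (the middle dot-separated segment) without verifying it. A slightly more defensive standalone sketch of that check, which also converts base64url to base64 before calling atob — something the code above assumes it can skip (names and the choice of the sub claim are illustrative):

// Decode a JWT payload without verifying the signature - fine for client-side
// "same user?" checks, never for trust decisions.
function decodeJwtPayload(token: string): Record<string, unknown> | null {
  const parts = token.split('.');
  if (parts.length !== 3) return null; // not a compact JWS token
  try {
    const b64 = parts[1].replace(/-/g, '+').replace(/_/g, '/'); // base64url -> base64
    return JSON.parse(atob(b64));
  } catch {
    return null; // malformed payload
  }
}

function sameAuthUser(tokenA: string, tokenB: string): boolean {
  const a = decodeJwtPayload(tokenA);
  const b = decodeJwtPayload(tokenB);
  if (!a || !b) return true; // on parse error, assume same session (as above)
  return a.sub === b.sub;    // compare the subject claim
}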
@@ -420,7 +428,7 @@ export const useTrainingWebSocket = (
         } else {
           clearInterval(heartbeatInterval);
         }
-      }, 30000); // Check every 30 seconds for token refresh and send ping
+      }, WEBSOCKET_HEARTBEAT_INTERVAL_MS); // Check for auth changes and send ping

       // Store interval for cleanup
       (ws as any).heartbeatInterval = heartbeatInterval;
@@ -449,7 +457,8 @@ export const useTrainingWebSocket = (
           if (initialData.type === 'product_completed') {
             const productsCompleted = initialEventData.products_completed || 0;
             const totalProducts = initialEventData.total_products || 1;
-            initialProgress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+            const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+            initialProgress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
             console.log('📦 Product training completed in initial state',
               `${productsCompleted}/${totalProducts}`,
               `progress: ${initialProgress}%`);
@@ -486,8 +495,9 @@ export const useTrainingWebSocket = (
             const productsCompleted = eventData.products_completed || 0;
             const totalProducts = eventData.total_products || 1;

-            // Calculate progress: 20% base + (completed/total * 60%)
-            progress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+            // Calculate progress: DATA_ANALYSIS% base + (completed/total * (TRAINING_RANGE_END - DATA_ANALYSIS)%)
+            const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+            progress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);

             console.log('📦 Product training completed',
               `${productsCompleted}/${totalProducts}`,
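As a worked example of the mapping above, with PROGRESS_DATA_ANALYSIS = 20 and PROGRESS_TRAINING_RANGE_END = 80: finishing 3 of 8 products yields 20 + floor(3/8 * 60) = 42%, and 8 of 8 yields exactly 80%. A sketch of the formula in isolation (defaults mirror the constants file; the function name is illustrative):

// Map per-product completion into the 20-80% training band.
function productProgress(completed: number, total: number, base = 20, rangeEnd = 80): number {
  const width = rangeEnd - base;        // 60 percentage points for product training
  const safeTotal = Math.max(total, 1); // mirrors the `|| 1` guard above
  return base + Math.floor((completed / safeTotal) * width);
}
// productProgress(3, 8) === 42; productProgress(8, 8) === 80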
@@ -585,8 +595,8 @@ export const useTrainingWebSocket = (
       // Detailed logging for different close codes
       switch (event.code) {
         case 1000:
-          if (event.reason === 'Token refreshed - reconnecting') {
-            console.log('🔄 WebSocket closed for token refresh - will reconnect immediately');
+          if (event.reason === 'Auth session changed - reconnecting') {
+            console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
           } else {
             console.log('🔒 WebSocket closed normally');
           }
@@ -604,18 +614,21 @@ export const useTrainingWebSocket = (
         console.log(`❓ WebSocket closed with code ${event.code}`);
       }

-      // Handle token refresh reconnection (immediate reconnect)
-      if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') {
-        console.log('🔄 Reconnecting immediately due to token refresh...');
+      // Handle auth session change reconnection (immediate reconnect)
+      if (event.code === 1000 && event.reason === 'Auth session changed - reconnecting') {
+        console.log('🔄 Reconnecting immediately due to auth session change...');
         reconnectTimer = setTimeout(() => {
-          connect(); // Reconnect immediately with fresh token
-        }, 1000); // Short delay to allow cleanup
+          connect(); // Reconnect immediately with new session token
+        }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
         return;
       }

       // Try to reconnect if not manually disconnected and haven't exceeded max attempts
       if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
-        const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s
+        const delay = Math.min(
+          WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
+          WEBSOCKET_RECONNECT_MAX_DELAY_MS
+        ); // Exponential backoff
         console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
         reconnectTimer = setTimeout(() => {
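With WEBSOCKET_RECONNECT_INITIAL_DELAY_MS = 1000 and WEBSOCKET_RECONNECT_MAX_DELAY_MS = 10000, the three allowed attempts wait 1 s, 2 s, and 4 s; the 10 s cap would only bite from the fourth doubling onward. The backoff in isolation (a sketch; names are illustrative):

// Capped exponential backoff: delay doubles per attempt, bounded by maxMs.
function backoffDelayMs(attempt: number, initialMs = 1000, maxMs = 10000): number {
  return Math.min(initialMs * Math.pow(2, attempt), maxMs);
}
// attempts 0..4 -> 1000, 2000, 4000, 8000, 10000 (capped)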

View File

@@ -5,6 +5,11 @@ import { Button } from '../../../ui/Button';
 import { useCurrentTenant } from '../../../../stores/tenant.store';
 import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training';
 import { Info } from 'lucide-react';
+import {
+  TRAINING_SKIP_OPTION_DELAY_MS,
+  TRAINING_COMPLETION_DELAY_MS,
+  SKIP_TIMER_CHECK_INTERVAL_MS
+} from '../../../../constants/training';

 interface MLTrainingStepProps {
   onNext: () => void;
@@ -38,16 +43,16 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
   const currentTenant = useCurrentTenant();
   const createTrainingJob = useCreateTrainingJob();

-  // Check if training has been running for more than 2 minutes
+  // Check if training has been running for more than the skip delay threshold
   useEffect(() => {
     if (trainingStartTime && isTraining && !showSkipOption) {
       const checkTimer = setInterval(() => {
         const elapsedTime = (Date.now() - trainingStartTime) / 1000; // in seconds
-        if (elapsedTime > 120) { // 2 minutes
+        if (elapsedTime > TRAINING_SKIP_OPTION_DELAY_MS / 1000) {
           setShowSkipOption(true);
           clearInterval(checkTimer);
         }
-      }, 5000); // Check every 5 seconds
+      }, SKIP_TIMER_CHECK_INTERVAL_MS);

       return () => clearInterval(checkTimer);
     }
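Note the units in the check above: elapsedTime is in seconds, so the millisecond constant is divided by 1000 at the comparison. A sketch of the same threshold kept in milliseconds end to end, which avoids the conversion entirely (an alternative, not what this PR does):

// Same skip-option threshold, computed purely in milliseconds.
function shouldOfferSkip(startedAtMs: number, nowMs: number, delayMs = 120_000): boolean {
  return nowMs - startedAtMs > delayMs; // true once training has run past the threshold
}
// shouldOfferSkip(0, 121_000) === true; shouldOfferSkip(0, 60_000) === false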
@@ -72,14 +77,14 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
       message: 'Entrenamiento completado exitosamente'
     });
     setIsTraining(false);

     setTimeout(() => {
       onComplete({
         jobId: jobId,
         success: true,
         message: 'Modelo entrenado correctamente'
       });
-    }, 2000);
+    }, TRAINING_COMPLETION_DELAY_MS);
   }, [onComplete, jobId]);
const handleError = useCallback((data: any) => { const handleError = useCallback((data: any) => {
@@ -147,7 +152,7 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
           message: 'Modelo entrenado correctamente',
           detectedViaPolling: true
         });
-      }, 2000);
+      }, TRAINING_COMPLETION_DELAY_MS);
     } else if (jobStatus.status === 'failed') {
       console.log(`❌ Training failure detected (source: ${isConnected ? 'WebSocket' : 'HTTP polling'})`);
       setError('Error detectado durante el entrenamiento (verificación de estado)');
@@ -169,13 +174,15 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
     }
   }, [jobStatus, jobId, trainingProgress?.stage, onComplete, isConnected]);

-  // Auto-trigger training when component mounts
+  // Auto-trigger training when component mounts (run once)
+  const hasAutoStarted = React.useRef(false);
   useEffect(() => {
-    if (currentTenant?.id && !isTraining && !trainingProgress && !error) {
+    if (currentTenant?.id && !hasAutoStarted.current && !isTraining && !trainingProgress && !error) {
       console.log('🚀 Auto-starting ML training for tenant:', currentTenant.id);
+      hasAutoStarted.current = true;
       handleStartTraining();
     }
-  }, [currentTenant?.id]); // Only run when tenant is available
+  }, [currentTenant?.id, isTraining, trainingProgress, error]); // Include all checked dependencies
   const handleStartTraining = async () => {
     if (!currentTenant?.id) {
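The hasAutoStarted ref above is the standard fix for an effect that must fire once even after its dependency list is widened: the ref latches before the side effect runs, so re-renders triggered by isTraining or trainingProgress cannot re-enter it. The pattern in isolation (a sketch; the hook name is illustrative):

import { useEffect, useRef } from 'react';

// Run a side effect at most once, as soon as `ready` becomes true.
function useRunOnce(ready: boolean, effect: () => void): void {
  const hasRun = useRef(false);
  useEffect(() => {
    if (ready && !hasRun.current) {
      hasRun.current = true; // latch first so re-renders cannot re-trigger
      effect();
    }
  }, [ready, effect]);
}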

View File

@@ -0,0 +1,25 @@
/**
 * Training Progress Constants
 * Centralized constants for training UI and behavior
 */

// Time Intervals (milliseconds)
export const TRAINING_SKIP_OPTION_DELAY_MS = 120000; // 2 minutes
export const TRAINING_COMPLETION_DELAY_MS = 2000; // 2 seconds
export const HTTP_POLLING_INTERVAL_MS = 5000; // 5 seconds
export const HTTP_POLLING_DEBOUNCE_MS = 5000; // 5 seconds
export const WEBSOCKET_HEARTBEAT_INTERVAL_MS = 30000; // 30 seconds

// WebSocket Configuration
export const WEBSOCKET_MAX_RECONNECT_ATTEMPTS = 3;
export const WEBSOCKET_RECONNECT_INITIAL_DELAY_MS = 1000; // 1 second
export const WEBSOCKET_RECONNECT_MAX_DELAY_MS = 10000; // 10 seconds

// Progress Milestones
export const PROGRESS_DATA_ANALYSIS = 20;
export const PROGRESS_TRAINING_RANGE_START = 20;
export const PROGRESS_TRAINING_RANGE_END = 80;
export const PROGRESS_COMPLETED = 100;

// Skip Timer Check Interval
export const SKIP_TIMER_CHECK_INTERVAL_MS = 5000; // 5 seconds

View File

@@ -189,15 +189,8 @@ async def start_training_job(
     # Calculate estimated completion time
     estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)

-    # Publish training.started event immediately so WebSocket clients
-    # have initial state when they connect
-    await publish_training_started(
-        job_id=job_id,
-        tenant_id=tenant_id,
-        total_products=0,  # Will be updated when actual training starts
-        estimated_duration_minutes=estimated_duration_minutes,
-        estimated_completion_time=estimated_completion_time.isoformat()
-    )
+    # Note: training.started event will be published by the trainer with accurate product count
+    # We don't publish here to avoid duplicate events

     # Add enhanced background task
     background_tasks.add_task(
@@ -401,11 +394,6 @@ async def execute_training_job_background(
             # Failure event is published by the training service
             await publish_training_failed(job_id, tenant_id, str(training_error))

-    except Exception as background_error:
-        logger.error("Critical error in enhanced background training job",
-                     job_id=job_id,
-                     error=str(background_error))
     finally:
         logger.info("Enhanced background training job cleanup completed",
                     job_id=job_id)

View File

@@ -0,0 +1,35 @@
"""
Training Progress Constants
Centralized constants for training progress tracking and timing
"""
# Progress Milestones (percentage)
PROGRESS_STARTED = 0
PROGRESS_DATA_VALIDATION = 10
PROGRESS_DATA_ANALYSIS = 20
PROGRESS_DATA_PREPARATION_COMPLETE = 30
PROGRESS_ML_TRAINING_START = 40
PROGRESS_TRAINING_COMPLETE = 85
PROGRESS_STORING_MODELS = 92
PROGRESS_STORING_METRICS = 94
PROGRESS_COMPLETED = 100
# Progress Ranges
PROGRESS_TRAINING_RANGE_START = 20 # After data analysis
PROGRESS_TRAINING_RANGE_END = 80 # Before finalization
PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START # 60%
# Time Limits and Intervals (seconds)
MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800 # 30 minutes
WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
# Training Timeouts (seconds)
TRAINING_SKIP_OPTION_DELAY_SECONDS = 120 # 2 minutes
HTTP_POLLING_INTERVAL_MS = 5000 # 5 seconds
HTTP_POLLING_DEBOUNCE_MS = 5000 # 5 seconds before enabling after WebSocket disconnect
# Frontend Display
TRAINING_COMPLETION_DELAY_MS = 2000 # Delay before navigating after completion

View File

@@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend
 from typing import Dict, List, Any, Optional
 import pandas as pd
 import numpy as np
-from datetime import datetime
+from datetime import datetime, timezone
 import structlog
 import uuid
 import time
@@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer:
             # Event 2: Data Analysis (20%)
             # Recalculate time remaining based on elapsed time
-            elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
+            start_time = await repos['training_log'].get_start_time(job_id)
+            elapsed_seconds = 0
+            if start_time:
+                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())

             # Estimate remaining time: we've done ~20% of work (data analysis)
             # Remaining 80% includes training all products
@@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer:
         except Exception as e:
             logger.error("Enhanced ML training pipeline failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)

             # Publish training failed event
             await publish_training_failed(job_id, tenant_id, str(e))
@@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer:
logger.error("Single product model training failed", logger.error("Single product model training failed",
job_id=job_id, job_id=job_id,
inventory_product_id=inventory_product_id, inventory_product_id=inventory_product_id,
error=str(e)) error=str(e),
exc_info=True)
raise raise
def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]: def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -329,4 +329,17 @@ class TrainingLogRepository(TrainingBaseRepository):
"min_duration_minutes": 0.0, "min_duration_minutes": 0.0,
"max_duration_minutes": 0.0, "max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0 "completed_jobs_with_duration": 0
} }
async def get_start_time(self, job_id: str) -> Optional[datetime]:
"""Get the start time for a training job"""
try:
log_entry = await self.get_by_job_id(job_id)
if log_entry and log_entry.start_time:
return log_entry.start_time
return None
except Exception as e:
logger.error("Failed to get start time",
job_id=job_id,
error=str(e))
return None

View File

@@ -10,6 +10,11 @@ from datetime import datetime, timezone
 from app.services.training_events import publish_product_training_completed
 from app.utils.time_estimation import calculate_estimated_completion_time
+from app.core.training_constants import (
+    PROGRESS_TRAINING_RANGE_START,
+    PROGRESS_TRAINING_RANGE_END,
+    PROGRESS_TRAINING_RANGE_WIDTH
+)

 logger = structlog.get_logger()
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
         self.start_time = datetime.now(timezone.utc)

         # Calculate progress increment per product
-        # 60% of total progress (from 20% to 80%) divided by number of products
-        self.progress_per_product = 60 / total_products if total_products > 0 else 0
+        # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
+        self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0

         logger.info("ParallelProductProgressTracker initialized",
                     job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
             estimated_completion_time=estimated_completion_time
         )

-        # Calculate overall progress (20% base + progress from completed products)
+        # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
         # This calculation is done on the frontend/consumer side based on the event data
-        overall_progress = 20 + int((current_progress / self.total_products) * 60)
+        overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)

         logger.info("Product training completed",
                     job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
         return {
             "products_completed": self.products_completed,
             "total_products": self.total_products,
-            "progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
+            "progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
         }

View File

@@ -135,6 +135,61 @@ async def publish_data_analysis(
     return success

+
+async def publish_training_progress(
+    job_id: str,
+    tenant_id: str,
+    progress: int,
+    current_step: str,
+    step_details: Optional[str] = None,
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
+) -> bool:
+    """
+    Generic Training Progress Event (for any progress percentage)
+
+    Args:
+        job_id: Training job identifier
+        tenant_id: Tenant identifier
+        progress: Progress percentage (0-100)
+        current_step: Current step name
+        step_details: Details about the current step
+        estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
+    """
+    event_data = {
+        "service_name": "training-service",
+        "event_type": "training.progress",
+        "timestamp": datetime.now().isoformat(),
+        "data": {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "progress": progress,
+            "current_step": current_step,
+            "step_details": step_details or current_step,
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
+        }
+    }
+
+    success = await training_publisher.publish_event(
+        exchange_name="training.events",
+        routing_key="training.progress",
+        event_data=event_data
+    )
+
+    if success:
+        logger.info("Published training progress event",
+                    job_id=job_id,
+                    progress=progress,
+                    current_step=current_step)
+    else:
+        logger.error("Failed to publish training progress event",
+                     job_id=job_id,
+                     progress=progress)
+
+    return success
+
 async def publish_product_training_completed(
     job_id: str,
     tenant_id: str,
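For reference, the envelope published above is what the frontend's WebSocket handler receives. A TypeScript sketch of the corresponding client-side shape and dispatch (field names mirror event_data above; the handler itself is illustrative, not part of this PR):

// Shape of the training.progress envelope (mirrors event_data above).
interface TrainingProgressEvent {
  service_name: string;
  event_type: 'training.progress';
  timestamp: string;
  data: {
    job_id: string;
    tenant_id: string;
    progress: number; // 0-100
    current_step: string;
    step_details: string;
    estimated_time_remaining_seconds: number | null;
    estimated_completion_time: string | null;
  };
}

// Illustrative dispatch on a raw WebSocket message.
function handleMessage(raw: string, onProgress: (e: TrainingProgressEvent) => void): void {
  const msg = JSON.parse(raw);
  if (msg?.event_type === 'training.progress') onProgress(msg as TrainingProgressEvent);
}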

View File

@@ -16,6 +16,15 @@ import pandas as pd
 from app.ml.trainer import EnhancedBakeryMLTrainer
 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
 from app.services.training_orchestrator import TrainingDataOrchestrator
+from app.core.training_constants import (
+    PROGRESS_DATA_VALIDATION,
+    PROGRESS_DATA_PREPARATION_COMPLETE,
+    PROGRESS_ML_TRAINING_START,
+    PROGRESS_TRAINING_COMPLETE,
+    PROGRESS_STORING_MODELS,
+    PROGRESS_STORING_METRICS,
+    MAX_ESTIMATED_TIME_REMAINING_SECONDS
+)

 # Import repositories
 from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
         # Step 1: Prepare training dataset (includes sales data validation)
         logger.info("Step 1: Preparing and aligning training data (with validation)")
         await self.training_log_repo.update_log_progress(
-            job_id, 10, "data_validation", "running"
+            job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
         )

         # Orchestrator now handles sales data validation to eliminate duplicate fetching
@@ -204,13 +213,13 @@ class EnhancedTrainingService:
                     tenant_id=tenant_id, job_id=job_id)

         await self.training_log_repo.update_log_progress(
-            job_id, 30, "data_preparation_complete", "running"
+            job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
         )

         # Step 2: Execute ML training pipeline
         logger.info("Step 2: Starting ML training pipeline")
         await self.training_log_repo.update_log_progress(
-            job_id, 40, "ml_training", "running"
+            job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
         )

         training_results = await self.trainer.train_tenant_models(
) )
await self.training_log_repo.update_log_progress( await self.training_log_repo.update_log_progress(
job_id, 85, "training_complete", "running" job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
) )
# Publish progress event (85%)
from app.services.training_events import publish_training_progress
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_TRAINING_COMPLETE,
current_step="Training Complete",
step_details="All products trained successfully"
)
# Step 3: Store model records using repository # Step 3: Store model records using repository
logger.info("Step 3: Storing model records") logger.info("Step 3: Storing model records")
logger.debug("Training results structure", logger.debug("Training results structure",
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
         )

         await self.training_log_repo.update_log_progress(
-            job_id, 92, "storing_models", "running"
+            job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
         )

+        # Publish progress event (92%)
+        await publish_training_progress(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            progress=PROGRESS_STORING_MODELS,
+            current_step="Storing Models",
+            step_details=f"Saved {len(stored_models)} trained models to database"
+        )
+
         # Step 4: Create performance metrics
         await self.training_log_repo.update_log_progress(
-            job_id, 94, "storing_performance_metrics", "running"
+            job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
         )

+        # Publish progress event (94%)
+        await publish_training_progress(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            progress=PROGRESS_STORING_METRICS,
+            current_step="Storing Performance Metrics",
+            step_details="Saving model performance metrics"
+        )
+
         await self._create_performance_metrics(
await self._create_performance_metrics( await self._create_performance_metrics(
@@ -316,8 +353,9 @@ class EnhancedTrainingService:
         except Exception as e:
             logger.error("Enhanced training job failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)

             # Mark as failed in database
             await self.training_log_repo.complete_training_log(
                 job_id, error_message=str(e)
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
if log.status == "running" and log.progress > 0 and log.start_time: if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds() elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0: if elapsed_time > 0 and log.progress > 0: # Double-check progress is positive
# Calculate estimated total time based on progress # Calculate estimated total time based on progress
estimated_total_time = (elapsed_time / log.progress) * 100 # Use max(log.progress, 1) as additional safety against division by zero
estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time) estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (e.g., 30 minutes) # Cap at reasonable maximum (30 minutes)
estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800) estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
# Extract products info from results if available # Extract products info from results if available
products_total = 0 products_total = 0
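To make the estimate concrete: a job that has run for 120 s at 40% progress implies an estimated total of 120 / 40 * 100 = 300 s, so roughly 180 s remain, which is then clamped to [0, 1800]. The arithmetic in isolation (a TypeScript sketch of the Python logic above; names are illustrative):

// Remaining-time estimate: scale elapsed time by reported progress, then clamp.
function estimateRemainingSeconds(elapsedS: number, progressPct: number, capS = 1800): number {
  const totalS = (elapsedS / Math.max(progressPct, 1)) * 100; // guard against progress = 0
  return Math.max(0, Math.min(Math.round(totalS - elapsedS), capS));
}
// estimateRemainingSeconds(120, 40) === 180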