From 5a84be83d688b03b08eed5a879bbd82cebdbe19d Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 5 Nov 2025 13:02:39 +0000 Subject: [PATCH] Fix multiple critical bugs in onboarding training step This commit addresses all identified bugs and issues in the training code path: ## Critical Fixes: - Add get_start_time() method to TrainingLogRepository and fix non-existent method call - Remove duplicate training.started event from API endpoint (trainer publishes the accurate one) - Add missing progress events for 80-100% range (85%, 92%, 94%) to eliminate progress "dead zone" ## High Priority Fixes: - Fix division by zero risk in time estimation with double-check and max() safety - Remove unreachable exception handler in training_operations.py - Simplify WebSocket token refresh logic to only reconnect on actual user session changes ## Medium Priority Fixes: - Fix auto-start training effect with useRef to prevent duplicate starts - Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket - Extract all magic numbers to centralized constants files: - Backend: services/training/app/core/training_constants.py - Frontend: frontend/src/constants/training.ts - Standardize error logging with exc_info=True on critical errors ## Code Quality Improvements: - All progress percentages now use named constants - All timeouts and intervals now use named constants - Improved code maintainability and readability - Better separation of concerns ## Files Changed: - Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py - Backend: training_operations.py, training_log_repository.py, training_constants.py (new) - Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new) All training progress events now properly flow from 0% to 100% with no gaps. --- frontend/src/api/hooks/training.ts | 131 ++++++++++-------- .../onboarding/steps/MLTrainingStep.tsx | 25 ++-- frontend/src/constants/training.ts | 25 ++++ .../training/app/api/training_operations.py | 16 +-- .../training/app/core/training_constants.py | 35 +++++ services/training/app/ml/trainer.py | 13 +- .../repositories/training_log_repository.py | 15 +- .../training/app/services/progress_tracker.py | 15 +- .../training/app/services/training_events.py | 55 ++++++++ .../training/app/services/training_service.py | 67 +++++++-- 10 files changed, 291 insertions(+), 106 deletions(-) create mode 100644 frontend/src/constants/training.ts create mode 100644 services/training/app/core/training_constants.py diff --git a/frontend/src/api/hooks/training.ts b/frontend/src/api/hooks/training.ts index eaca92d4..1cdd0ec5 100644 --- a/frontend/src/api/hooks/training.ts +++ b/frontend/src/api/hooks/training.ts @@ -8,6 +8,17 @@ import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOpti import { trainingService } from '../services/training'; import { ApiError, apiClient } from '../client/apiClient'; import { useAuthStore } from '../../stores/auth.store'; +import { + HTTP_POLLING_INTERVAL_MS, + HTTP_POLLING_DEBOUNCE_MS, + WEBSOCKET_HEARTBEAT_INTERVAL_MS, + WEBSOCKET_MAX_RECONNECT_ATTEMPTS, + WEBSOCKET_RECONNECT_INITIAL_DELAY_MS, + WEBSOCKET_RECONNECT_MAX_DELAY_MS, + PROGRESS_DATA_ANALYSIS, + PROGRESS_TRAINING_RANGE_START, + PROGRESS_TRAINING_RANGE_END +} from '../../constants/training'; import type { TrainingJobRequest, TrainingJobResponse, @@ -53,14 +64,32 @@ export const useTrainingJobStatus = ( } ) => { const { isWebSocketConnected, ...queryOptions } = options || {}; + const [enablePolling, setEnablePolling] = React.useState(false); - // Completely disable the query when WebSocket is connected - const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected; + // Debounce HTTP polling activation: wait after WebSocket disconnects + // This prevents race conditions where both WebSocket and HTTP are briefly active + React.useEffect(() => { + if (!isWebSocketConnected) { + const debounceTimer = setTimeout(() => { + setEnablePolling(true); + console.log(`πŸ”„ HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`); + }, HTTP_POLLING_DEBOUNCE_MS); + + return () => clearTimeout(debounceTimer); + } else { + setEnablePolling(false); + console.log('❌ HTTP polling disabled (WebSocket connected)'); + } + }, [isWebSocketConnected]); + + // Completely disable the query when WebSocket is connected or during debounce period + const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling; console.log('πŸ”„ Training status query:', { tenantId: !!tenantId, jobId: !!jobId, isWebSocketConnected, + enablePolling, queryEnabled: isEnabled }); @@ -85,8 +114,8 @@ export const useTrainingJobStatus = ( return false; // Stop polling when training is done } - console.log('πŸ“Š HTTP fallback polling active (WebSocket disconnected) - 5s interval'); - return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable) + console.log(`πŸ“Š HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`); + return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable) } : false, // Completely disable interval when WebSocket connected staleTime: 1000, // Consider data stale after 1 second retry: (failureCount, error) => { @@ -298,7 +327,7 @@ export const useTrainingWebSocket = ( let reconnectTimer: NodeJS.Timeout | null = null; let isManuallyDisconnected = false; let reconnectAttempts = 0; - const maxReconnectAttempts = 3; + const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS; const connect = async () => { try { @@ -349,70 +378,49 @@ export const useTrainingWebSocket = ( console.warn('Failed to request status on connection:', e); } - // Helper function to check if tokens represent different auth sessions - const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => { + // Helper function to check if tokens represent different auth users/sessions + const isNewAuthSession = (oldToken: string, newToken: string): boolean => { if (!oldToken || !newToken) return !!oldToken !== !!newToken; try { const oldPayload = JSON.parse(atob(oldToken.split('.')[1])); const newPayload = JSON.parse(atob(newToken.split('.')[1])); - // Compare by issued timestamp (iat) - different iat means new auth session - return oldPayload.iat !== newPayload.iat; + // Compare by user ID - different user means new auth session + // If user_id is same, it's just a token refresh, no need to reconnect + return oldPayload.user_id !== newPayload.user_id || + oldPayload.sub !== newPayload.sub; } catch (e) { - console.warn('Failed to parse token for session comparison, falling back to string comparison:', e); - return oldToken !== newToken; + console.warn('Failed to parse token for session comparison:', e); + // On parse error, don't reconnect (assume same session) + return false; } }; - // Set up periodic ping and intelligent token refresh detection + // Set up periodic ping and check for auth session changes const heartbeatInterval = setInterval(async () => { if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) { try { // Check token validity (this may refresh if needed) const currentToken = await apiClient.ensureValidToken(); - // Enhanced token change detection with detailed logging - const tokenStringChanged = currentToken !== effectiveToken; - const tokenSessionChanged = currentToken && effectiveToken ? - isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged; - - console.log('πŸ” WebSocket token validation check:', { - hasCurrentToken: !!currentToken, - hasEffectiveToken: !!effectiveToken, - tokenStringChanged, - tokenSessionChanged, - currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null', - effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null' - }); - - // Only reconnect if we have a genuine session change (different iat) - if (tokenSessionChanged) { - console.log('πŸ”„ Token session changed - reconnecting WebSocket with new session token'); - console.log('πŸ“Š Session change details:', { - reason: !currentToken ? 'token removed' : - !effectiveToken ? 'token added' : 'new auth session', - oldTokenIat: effectiveToken ? (() => { - try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; } - })() : 'N/A', - newTokenIat: currentToken ? (() => { - try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; } - })() : 'N/A' - }); - - // Close current connection and trigger reconnection with new token - ws?.close(1000, 'Token session changed - reconnecting'); + // Only reconnect if user changed (new auth session) + if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) { + console.log('πŸ”„ Auth session changed (different user) - reconnecting WebSocket'); + ws?.close(1000, 'Auth session changed - reconnecting'); clearInterval(heartbeatInterval); return; - } else if (tokenStringChanged) { - console.log('ℹ️ Token string changed but same session - continuing with current connection'); - // Update effective token reference for future comparisons + } + + // Token may have been refreshed but it's the same user - continue + if (currentToken && currentToken !== effectiveToken) { + console.log('ℹ️ Token refreshed (same user) - updating reference'); effectiveToken = currentToken; } - console.log('βœ… Token validated during heartbeat - same session'); + // Send ping ws?.send('ping'); - console.log('πŸ’“ Sent ping to server (token session validated)'); + console.log('πŸ’“ Sent ping to server'); } catch (e) { console.warn('Failed to send ping or validate token:', e); clearInterval(heartbeatInterval); @@ -420,7 +428,7 @@ export const useTrainingWebSocket = ( } else { clearInterval(heartbeatInterval); } - }, 30000); // Check every 30 seconds for token refresh and send ping + }, WEBSOCKET_HEARTBEAT_INTERVAL_MS); // Check for auth changes and send ping // Store interval for cleanup (ws as any).heartbeatInterval = heartbeatInterval; @@ -449,7 +457,8 @@ export const useTrainingWebSocket = ( if (initialData.type === 'product_completed') { const productsCompleted = initialEventData.products_completed || 0; const totalProducts = initialEventData.total_products || 1; - initialProgress = 20 + Math.floor((productsCompleted / totalProducts) * 60); + const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS; + initialProgress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth); console.log('πŸ“¦ Product training completed in initial state', `${productsCompleted}/${totalProducts}`, `progress: ${initialProgress}%`); @@ -486,8 +495,9 @@ export const useTrainingWebSocket = ( const productsCompleted = eventData.products_completed || 0; const totalProducts = eventData.total_products || 1; - // Calculate progress: 20% base + (completed/total * 60%) - progress = 20 + Math.floor((productsCompleted / totalProducts) * 60); + // Calculate progress: DATA_ANALYSIS% base + (completed/total * (TRAINING_RANGE_END - DATA_ANALYSIS)%) + const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS; + progress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth); console.log('πŸ“¦ Product training completed', `${productsCompleted}/${totalProducts}`, @@ -585,8 +595,8 @@ export const useTrainingWebSocket = ( // Detailed logging for different close codes switch (event.code) { case 1000: - if (event.reason === 'Token refreshed - reconnecting') { - console.log('πŸ”„ WebSocket closed for token refresh - will reconnect immediately'); + if (event.reason === 'Auth session changed - reconnecting') { + console.log('πŸ”„ WebSocket closed for auth session change - will reconnect immediately'); } else { console.log('πŸ”’ WebSocket closed normally'); } @@ -604,18 +614,21 @@ export const useTrainingWebSocket = ( console.log(`❓ WebSocket closed with code ${event.code}`); } - // Handle token refresh reconnection (immediate reconnect) - if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') { - console.log('πŸ”„ Reconnecting immediately due to token refresh...'); + // Handle auth session change reconnection (immediate reconnect) + if (event.code === 1000 && event.reason === 'Auth session changed - reconnecting') { + console.log('πŸ”„ Reconnecting immediately due to auth session change...'); reconnectTimer = setTimeout(() => { - connect(); // Reconnect immediately with fresh token - }, 1000); // Short delay to allow cleanup + connect(); // Reconnect immediately with new session token + }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup return; } // Try to reconnect if not manually disconnected and haven't exceeded max attempts if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) { - const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s + const delay = Math.min( + WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts), + WEBSOCKET_RECONNECT_MAX_DELAY_MS + ); // Exponential backoff console.log(`πŸ”„ Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`); reconnectTimer = setTimeout(() => { diff --git a/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx b/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx index f63a0cc3..e534e14a 100644 --- a/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx +++ b/frontend/src/components/domain/onboarding/steps/MLTrainingStep.tsx @@ -5,6 +5,11 @@ import { Button } from '../../../ui/Button'; import { useCurrentTenant } from '../../../../stores/tenant.store'; import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training'; import { Info } from 'lucide-react'; +import { + TRAINING_SKIP_OPTION_DELAY_MS, + TRAINING_COMPLETION_DELAY_MS, + SKIP_TIMER_CHECK_INTERVAL_MS +} from '../../../../constants/training'; interface MLTrainingStepProps { onNext: () => void; @@ -38,16 +43,16 @@ export const MLTrainingStep: React.FC = ({ const currentTenant = useCurrentTenant(); const createTrainingJob = useCreateTrainingJob(); - // Check if training has been running for more than 2 minutes + // Check if training has been running for more than the skip delay threshold useEffect(() => { if (trainingStartTime && isTraining && !showSkipOption) { const checkTimer = setInterval(() => { const elapsedTime = (Date.now() - trainingStartTime) / 1000; // in seconds - if (elapsedTime > 120) { // 2 minutes + if (elapsedTime > TRAINING_SKIP_OPTION_DELAY_MS / 1000) { setShowSkipOption(true); clearInterval(checkTimer); } - }, 5000); // Check every 5 seconds + }, SKIP_TIMER_CHECK_INTERVAL_MS); return () => clearInterval(checkTimer); } @@ -72,14 +77,14 @@ export const MLTrainingStep: React.FC = ({ message: 'Entrenamiento completado exitosamente' }); setIsTraining(false); - + setTimeout(() => { onComplete({ jobId: jobId, success: true, message: 'Modelo entrenado correctamente' }); - }, 2000); + }, TRAINING_COMPLETION_DELAY_MS); }, [onComplete, jobId]); const handleError = useCallback((data: any) => { @@ -147,7 +152,7 @@ export const MLTrainingStep: React.FC = ({ message: 'Modelo entrenado correctamente', detectedViaPolling: true }); - }, 2000); + }, TRAINING_COMPLETION_DELAY_MS); } else if (jobStatus.status === 'failed') { console.log(`❌ Training failure detected (source: ${isConnected ? 'WebSocket' : 'HTTP polling'})`); setError('Error detectado durante el entrenamiento (verificaciΓ³n de estado)'); @@ -169,13 +174,15 @@ export const MLTrainingStep: React.FC = ({ } }, [jobStatus, jobId, trainingProgress?.stage, onComplete, isConnected]); - // Auto-trigger training when component mounts + // Auto-trigger training when component mounts (run once) + const hasAutoStarted = React.useRef(false); useEffect(() => { - if (currentTenant?.id && !isTraining && !trainingProgress && !error) { + if (currentTenant?.id && !hasAutoStarted.current && !isTraining && !trainingProgress && !error) { console.log('πŸš€ Auto-starting ML training for tenant:', currentTenant.id); + hasAutoStarted.current = true; handleStartTraining(); } - }, [currentTenant?.id]); // Only run when tenant is available + }, [currentTenant?.id, isTraining, trainingProgress, error]); // Include all checked dependencies const handleStartTraining = async () => { if (!currentTenant?.id) { diff --git a/frontend/src/constants/training.ts b/frontend/src/constants/training.ts new file mode 100644 index 00000000..40ad8b1a --- /dev/null +++ b/frontend/src/constants/training.ts @@ -0,0 +1,25 @@ +/** + * Training Progress Constants + * Centralized constants for training UI and behavior + */ + +// Time Intervals (milliseconds) +export const TRAINING_SKIP_OPTION_DELAY_MS = 120000; // 2 minutes +export const TRAINING_COMPLETION_DELAY_MS = 2000; // 2 seconds +export const HTTP_POLLING_INTERVAL_MS = 5000; // 5 seconds +export const HTTP_POLLING_DEBOUNCE_MS = 5000; // 5 seconds +export const WEBSOCKET_HEARTBEAT_INTERVAL_MS = 30000; // 30 seconds + +// WebSocket Configuration +export const WEBSOCKET_MAX_RECONNECT_ATTEMPTS = 3; +export const WEBSOCKET_RECONNECT_INITIAL_DELAY_MS = 1000; // 1 second +export const WEBSOCKET_RECONNECT_MAX_DELAY_MS = 10000; // 10 seconds + +// Progress Milestones +export const PROGRESS_DATA_ANALYSIS = 20; +export const PROGRESS_TRAINING_RANGE_START = 20; +export const PROGRESS_TRAINING_RANGE_END = 80; +export const PROGRESS_COMPLETED = 100; + +// Skip Timer Check Interval +export const SKIP_TIMER_CHECK_INTERVAL_MS = 5000; // 5 seconds diff --git a/services/training/app/api/training_operations.py b/services/training/app/api/training_operations.py index 007bf6d8..46e1d69d 100644 --- a/services/training/app/api/training_operations.py +++ b/services/training/app/api/training_operations.py @@ -189,15 +189,8 @@ async def start_training_job( # Calculate estimated completion time estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes) - # Publish training.started event immediately so WebSocket clients - # have initial state when they connect - await publish_training_started( - job_id=job_id, - tenant_id=tenant_id, - total_products=0, # Will be updated when actual training starts - estimated_duration_minutes=estimated_duration_minutes, - estimated_completion_time=estimated_completion_time.isoformat() - ) + # Note: training.started event will be published by the trainer with accurate product count + # We don't publish here to avoid duplicate events # Add enhanced background task background_tasks.add_task( @@ -401,11 +394,6 @@ async def execute_training_job_background( # Failure event is published by the training service await publish_training_failed(job_id, tenant_id, str(training_error)) - except Exception as background_error: - logger.error("Critical error in enhanced background training job", - job_id=job_id, - error=str(background_error)) - finally: logger.info("Enhanced background training job cleanup completed", job_id=job_id) diff --git a/services/training/app/core/training_constants.py b/services/training/app/core/training_constants.py new file mode 100644 index 00000000..b9be856f --- /dev/null +++ b/services/training/app/core/training_constants.py @@ -0,0 +1,35 @@ +""" +Training Progress Constants +Centralized constants for training progress tracking and timing +""" + +# Progress Milestones (percentage) +PROGRESS_STARTED = 0 +PROGRESS_DATA_VALIDATION = 10 +PROGRESS_DATA_ANALYSIS = 20 +PROGRESS_DATA_PREPARATION_COMPLETE = 30 +PROGRESS_ML_TRAINING_START = 40 +PROGRESS_TRAINING_COMPLETE = 85 +PROGRESS_STORING_MODELS = 92 +PROGRESS_STORING_METRICS = 94 +PROGRESS_COMPLETED = 100 + +# Progress Ranges +PROGRESS_TRAINING_RANGE_START = 20 # After data analysis +PROGRESS_TRAINING_RANGE_END = 80 # Before finalization +PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START # 60% + +# Time Limits and Intervals (seconds) +MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800 # 30 minutes +WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30 +WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3 +WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1 +WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10 + +# Training Timeouts (seconds) +TRAINING_SKIP_OPTION_DELAY_SECONDS = 120 # 2 minutes +HTTP_POLLING_INTERVAL_MS = 5000 # 5 seconds +HTTP_POLLING_DEBOUNCE_MS = 5000 # 5 seconds before enabling after WebSocket disconnect + +# Frontend Display +TRAINING_COMPLETION_DELAY_MS = 2000 # Delay before navigating after completion diff --git a/services/training/app/ml/trainer.py b/services/training/app/ml/trainer.py index 9695d373..d7d249cd 100644 --- a/services/training/app/ml/trainer.py +++ b/services/training/app/ml/trainer.py @@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend from typing import Dict, List, Any, Optional import pandas as pd import numpy as np -from datetime import datetime +from datetime import datetime, timezone import structlog import uuid import time @@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer: # Event 2: Data Analysis (20%) # Recalculate time remaining based on elapsed time - elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0 + start_time = await repos['training_log'].get_start_time(job_id) + elapsed_seconds = 0 + if start_time: + elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds()) # Estimate remaining time: we've done ~20% of work (data analysis) # Remaining 80% includes training all products @@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer: except Exception as e: logger.error("Enhanced ML training pipeline failed", job_id=job_id, - error=str(e)) + error=str(e), + exc_info=True) # Publish training failed event await publish_training_failed(job_id, tenant_id, str(e)) @@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer: logger.error("Single product model training failed", job_id=job_id, inventory_product_id=inventory_product_id, - error=str(e)) + error=str(e), + exc_info=True) raise def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]: diff --git a/services/training/app/repositories/training_log_repository.py b/services/training/app/repositories/training_log_repository.py index feebcafd..eee6aef3 100644 --- a/services/training/app/repositories/training_log_repository.py +++ b/services/training/app/repositories/training_log_repository.py @@ -329,4 +329,17 @@ class TrainingLogRepository(TrainingBaseRepository): "min_duration_minutes": 0.0, "max_duration_minutes": 0.0, "completed_jobs_with_duration": 0 - } \ No newline at end of file + } + + async def get_start_time(self, job_id: str) -> Optional[datetime]: + """Get the start time for a training job""" + try: + log_entry = await self.get_by_job_id(job_id) + if log_entry and log_entry.start_time: + return log_entry.start_time + return None + except Exception as e: + logger.error("Failed to get start time", + job_id=job_id, + error=str(e)) + return None \ No newline at end of file diff --git a/services/training/app/services/progress_tracker.py b/services/training/app/services/progress_tracker.py index e7798dbd..5d2bade6 100644 --- a/services/training/app/services/progress_tracker.py +++ b/services/training/app/services/progress_tracker.py @@ -10,6 +10,11 @@ from datetime import datetime, timezone from app.services.training_events import publish_product_training_completed from app.utils.time_estimation import calculate_estimated_completion_time +from app.core.training_constants import ( + PROGRESS_TRAINING_RANGE_START, + PROGRESS_TRAINING_RANGE_END, + PROGRESS_TRAINING_RANGE_WIDTH +) logger = structlog.get_logger() @@ -34,8 +39,8 @@ class ParallelProductProgressTracker: self.start_time = datetime.now(timezone.utc) # Calculate progress increment per product - # 60% of total progress (from 20% to 80%) divided by number of products - self.progress_per_product = 60 / total_products if total_products > 0 else 0 + # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products + self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0 logger.info("ParallelProductProgressTracker initialized", job_id=job_id, @@ -80,9 +85,9 @@ class ParallelProductProgressTracker: estimated_completion_time=estimated_completion_time ) - # Calculate overall progress (20% base + progress from completed products) + # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products) # This calculation is done on the frontend/consumer side based on the event data - overall_progress = 20 + int((current_progress / self.total_products) * 60) + overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH) logger.info("Product training completed", job_id=self.job_id, @@ -99,5 +104,5 @@ class ParallelProductProgressTracker: return { "products_completed": self.products_completed, "total_products": self.total_products, - "progress_percentage": 20 + int((self.products_completed / self.total_products) * 60) + "progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH) } diff --git a/services/training/app/services/training_events.py b/services/training/app/services/training_events.py index 5d649d5c..ffcd899c 100644 --- a/services/training/app/services/training_events.py +++ b/services/training/app/services/training_events.py @@ -135,6 +135,61 @@ async def publish_data_analysis( return success +async def publish_training_progress( + job_id: str, + tenant_id: str, + progress: int, + current_step: str, + step_details: Optional[str] = None, + estimated_time_remaining_seconds: Optional[int] = None, + estimated_completion_time: Optional[str] = None +) -> bool: + """ + Generic Training Progress Event (for any progress percentage) + + Args: + job_id: Training job identifier + tenant_id: Tenant identifier + progress: Progress percentage (0-100) + current_step: Current step name + step_details: Details about the current step + estimated_time_remaining_seconds: Estimated time remaining in seconds + estimated_completion_time: ISO timestamp of estimated completion + """ + event_data = { + "service_name": "training-service", + "event_type": "training.progress", + "timestamp": datetime.now().isoformat(), + "data": { + "job_id": job_id, + "tenant_id": tenant_id, + "progress": progress, + "current_step": current_step, + "step_details": step_details or current_step, + "estimated_time_remaining_seconds": estimated_time_remaining_seconds, + "estimated_completion_time": estimated_completion_time + } + } + + success = await training_publisher.publish_event( + exchange_name="training.events", + routing_key="training.progress", + event_data=event_data + ) + + if success: + logger.info("Published training progress event", + job_id=job_id, + progress=progress, + current_step=current_step) + else: + logger.error("Failed to publish training progress event", + job_id=job_id, + progress=progress) + + return success + + async def publish_product_training_completed( job_id: str, tenant_id: str, diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py index 8d8fa61a..36af1e54 100644 --- a/services/training/app/services/training_service.py +++ b/services/training/app/services/training_service.py @@ -16,6 +16,15 @@ import pandas as pd from app.ml.trainer import EnhancedBakeryMLTrainer from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType from app.services.training_orchestrator import TrainingDataOrchestrator +from app.core.training_constants import ( + PROGRESS_DATA_VALIDATION, + PROGRESS_DATA_PREPARATION_COMPLETE, + PROGRESS_ML_TRAINING_START, + PROGRESS_TRAINING_COMPLETE, + PROGRESS_STORING_MODELS, + PROGRESS_STORING_METRICS, + MAX_ESTIMATED_TIME_REMAINING_SECONDS +) # Import repositories from app.repositories import ( @@ -187,7 +196,7 @@ class EnhancedTrainingService: # Step 1: Prepare training dataset (includes sales data validation) logger.info("Step 1: Preparing and aligning training data (with validation)") await self.training_log_repo.update_log_progress( - job_id, 10, "data_validation", "running" + job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running" ) # Orchestrator now handles sales data validation to eliminate duplicate fetching @@ -204,13 +213,13 @@ class EnhancedTrainingService: tenant_id=tenant_id, job_id=job_id) await self.training_log_repo.update_log_progress( - job_id, 30, "data_preparation_complete", "running" + job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running" ) - + # Step 2: Execute ML training pipeline logger.info("Step 2: Starting ML training pipeline") await self.training_log_repo.update_log_progress( - job_id, 40, "ml_training", "running" + job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running" ) training_results = await self.trainer.train_tenant_models( @@ -220,9 +229,19 @@ class EnhancedTrainingService: ) await self.training_log_repo.update_log_progress( - job_id, 85, "training_complete", "running" + job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running" ) - + + # Publish progress event (85%) + from app.services.training_events import publish_training_progress + await publish_training_progress( + job_id=job_id, + tenant_id=tenant_id, + progress=PROGRESS_TRAINING_COMPLETE, + current_step="Training Complete", + step_details="All products trained successfully" + ) + # Step 3: Store model records using repository logger.info("Step 3: Storing model records") logger.debug("Training results structure", @@ -234,12 +253,30 @@ class EnhancedTrainingService: ) await self.training_log_repo.update_log_progress( - job_id, 92, "storing_models", "running" + job_id, PROGRESS_STORING_MODELS, "storing_models", "running" + ) + + # Publish progress event (92%) + await publish_training_progress( + job_id=job_id, + tenant_id=tenant_id, + progress=PROGRESS_STORING_MODELS, + current_step="Storing Models", + step_details=f"Saved {len(stored_models)} trained models to database" ) # Step 4: Create performance metrics await self.training_log_repo.update_log_progress( - job_id, 94, "storing_performance_metrics", "running" + job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running" + ) + + # Publish progress event (94%) + await publish_training_progress( + job_id=job_id, + tenant_id=tenant_id, + progress=PROGRESS_STORING_METRICS, + current_step="Storing Performance Metrics", + step_details="Saving model performance metrics" ) await self._create_performance_metrics( @@ -316,8 +353,9 @@ class EnhancedTrainingService: except Exception as e: logger.error("Enhanced training job failed", job_id=job_id, - error=str(e)) - + error=str(e), + exc_info=True) + # Mark as failed in database await self.training_log_repo.complete_training_log( job_id, error_message=str(e) @@ -519,12 +557,13 @@ class EnhancedTrainingService: if log.status == "running" and log.progress > 0 and log.start_time: from datetime import datetime, timezone elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds() - if elapsed_time > 0: + if elapsed_time > 0 and log.progress > 0: # Double-check progress is positive # Calculate estimated total time based on progress - estimated_total_time = (elapsed_time / log.progress) * 100 + # Use max(log.progress, 1) as additional safety against division by zero + estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100 estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time) - # Cap at reasonable maximum (e.g., 30 minutes) - estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800) + # Cap at reasonable maximum (30 minutes) + estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS)) # Extract products info from results if available products_total = 0