Fix multiple critical bugs in onboarding training step

This commit addresses all identified bugs and issues in the training code path:

## Critical Fixes:
- Add get_start_time() method to TrainingLogRepository, replacing a call to the non-existent _get_start_time()
- Remove duplicate training.started event from API endpoint (trainer publishes the accurate one)
- Add missing progress events for the 80-100% range (85%, 92%, 94%) to eliminate the progress "dead zone" (see the sketch below)
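
The dead zone existed because nothing was published between the end of the per-product training range (80%) and completion (100%). A minimal sketch of the resulting end-to-end mapping, using milestone values from the new constants files (the helper name is illustrative, not part of the codebase):

```ts
// Per-product progress maps completed/total onto the 20-80% training range.
const PROGRESS_TRAINING_RANGE_START = 20;
const PROGRESS_TRAINING_RANGE_WIDTH = 60; // PROGRESS_TRAINING_RANGE_END (80) - START (20)

function overallProgress(productsCompleted: number, totalProducts: number): number {
  return PROGRESS_TRAINING_RANGE_START +
    Math.floor((productsCompleted / totalProducts) * PROGRESS_TRAINING_RANGE_WIDTH);
}

// 3 of 10 products trained -> 20 + 18 = 38%. After 80%, the new milestone
// events fill the former gap: 85 (training complete), 92 (storing models),
// 94 (storing metrics), 100 (completed).
console.log(overallProgress(3, 10)); // 38
```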

## High Priority Fixes:
- Guard against division by zero in time estimation with a progress double-check and a max() clamp (see the sketch after this list)
- Remove unreachable exception handler in training_operations.py
- Simplify WebSocket token refresh logic to only reconnect on actual user session changes
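
The estimator extrapolates total duration from elapsed time and current progress, so a progress of zero would divide by zero. A minimal sketch of the guarded calculation (the function name is illustrative; the cap matches MAX_ESTIMATED_TIME_REMAINING_SECONDS in the new backend constants file):

```ts
const MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800; // 30-minute cap

function estimateRemainingSeconds(elapsedSeconds: number, progressPercent: number): number {
  // max() clamps progress to at least 1, so the division can never be by zero.
  const estimatedTotal = (elapsedSeconds / Math.max(progressPercent, 1)) * 100;
  // Clamp to [0, cap] so the UI never shows a negative or absurd remainder.
  return Math.max(0, Math.min(Math.floor(estimatedTotal - elapsedSeconds),
    MAX_ESTIMATED_TIME_REMAINING_SECONDS));
}

console.log(estimateRemainingSeconds(300, 20)); // 300s elapsed at 20% -> 1200
```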

## Medium Priority Fixes:
- Fix auto-start training effect with a useRef guard to prevent duplicate starts (see the sketch after this list)
- Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket
- Extract all magic numbers to centralized constants files:
  - Backend: services/training/app/core/training_constants.py
  - Frontend: frontend/src/constants/training.ts
- Standardize error logging with exc_info=True on critical errors
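
The auto-start effect previously re-ran whenever its dependencies changed, which could fire duplicate training requests. A minimal sketch of the run-once ref guard (assumes React; the hook name and parameters are illustrative):

```ts
import { useEffect, useRef } from 'react';

function useAutoStart(ready: boolean, start: () => void) {
  // A ref persists across re-renders without triggering them, so start()
  // fires at most once even if the effect re-runs with the same inputs.
  const hasStarted = useRef(false);
  useEffect(() => {
    if (ready && !hasStarted.current) {
      hasStarted.current = true;
      start();
    }
  }, [ready, start]);
}
```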

## Code Quality Improvements:
- All progress percentages now use named constants
- All timeouts and intervals now use named constants
- Improved code maintainability and readability
- Better separation of concerns

## Files Changed:
- Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py
- Backend: training_operations.py, training_log_repository.py, training_constants.py (new)
- Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new)

All training progress events now properly flow from 0% to 100% with no gaps.
Author: Claude
Date: 2025-11-05 13:02:39 +00:00
Parent: e3ea92640b
Commit: 5a84be83d6
10 changed files with 291 additions and 106 deletions

File: frontend/src/api/hooks/training.ts

@@ -8,6 +8,17 @@ import { useMutation, useQuery, useQueryClient, UseQueryOptions, UseMutationOpti
import { trainingService } from '../services/training';
import { ApiError, apiClient } from '../client/apiClient';
import { useAuthStore } from '../../stores/auth.store';
+ import {
+   HTTP_POLLING_INTERVAL_MS,
+   HTTP_POLLING_DEBOUNCE_MS,
+   WEBSOCKET_HEARTBEAT_INTERVAL_MS,
+   WEBSOCKET_MAX_RECONNECT_ATTEMPTS,
+   WEBSOCKET_RECONNECT_INITIAL_DELAY_MS,
+   WEBSOCKET_RECONNECT_MAX_DELAY_MS,
+   PROGRESS_DATA_ANALYSIS,
+   PROGRESS_TRAINING_RANGE_START,
+   PROGRESS_TRAINING_RANGE_END
+ } from '../../constants/training';
import type {
TrainingJobRequest,
TrainingJobResponse,
@@ -53,14 +64,32 @@ export const useTrainingJobStatus = (
}
) => {
const { isWebSocketConnected, ...queryOptions } = options || {};
+ const [enablePolling, setEnablePolling] = React.useState(false);
- // Completely disable the query when WebSocket is connected
- const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected;
+ // Debounce HTTP polling activation: wait after WebSocket disconnects
+ // This prevents race conditions where both WebSocket and HTTP are briefly active
+ React.useEffect(() => {
+   if (!isWebSocketConnected) {
+     const debounceTimer = setTimeout(() => {
+       setEnablePolling(true);
+       console.log(`🔄 HTTP polling enabled after ${HTTP_POLLING_DEBOUNCE_MS}ms debounce (WebSocket disconnected)`);
+     }, HTTP_POLLING_DEBOUNCE_MS);
+     return () => clearTimeout(debounceTimer);
+   } else {
+     setEnablePolling(false);
+     console.log('❌ HTTP polling disabled (WebSocket connected)');
+   }
+ }, [isWebSocketConnected]);
+ // Completely disable the query when WebSocket is connected or during debounce period
+ const isEnabled = !!tenantId && !!jobId && !isWebSocketConnected && enablePolling;
console.log('🔄 Training status query:', {
tenantId: !!tenantId,
jobId: !!jobId,
isWebSocketConnected,
+ enablePolling,
queryEnabled: isEnabled
});
@@ -85,8 +114,8 @@ export const useTrainingJobStatus = (
return false; // Stop polling when training is done
}
- console.log('📊 HTTP fallback polling active (WebSocket disconnected) - 5s interval');
- return 5000; // Poll every 5 seconds while training (fallback when WebSocket unavailable)
+ console.log(`📊 HTTP fallback polling active (WebSocket disconnected) - ${HTTP_POLLING_INTERVAL_MS}ms interval`);
+ return HTTP_POLLING_INTERVAL_MS; // Poll while training (fallback when WebSocket unavailable)
} : false, // Completely disable interval when WebSocket connected
staleTime: 1000, // Consider data stale after 1 second
retry: (failureCount, error) => {
@@ -298,7 +327,7 @@ export const useTrainingWebSocket = (
let reconnectTimer: NodeJS.Timeout | null = null;
let isManuallyDisconnected = false;
let reconnectAttempts = 0;
- const maxReconnectAttempts = 3;
+ const maxReconnectAttempts = WEBSOCKET_MAX_RECONNECT_ATTEMPTS;
const connect = async () => {
try {
@@ -349,70 +378,49 @@ export const useTrainingWebSocket = (
console.warn('Failed to request status on connection:', e);
}
- // Helper function to check if tokens represent different auth sessions
- const isTokenSessionDifferent = (oldToken: string, newToken: string): boolean => {
+ // Helper function to check if tokens represent different auth users/sessions
+ const isNewAuthSession = (oldToken: string, newToken: string): boolean => {
if (!oldToken || !newToken) return !!oldToken !== !!newToken;
try {
const oldPayload = JSON.parse(atob(oldToken.split('.')[1]));
const newPayload = JSON.parse(atob(newToken.split('.')[1]));
- // Compare by issued timestamp (iat) - different iat means new auth session
- return oldPayload.iat !== newPayload.iat;
+ // Compare by user ID - different user means new auth session
+ // If user_id is same, it's just a token refresh, no need to reconnect
+ return oldPayload.user_id !== newPayload.user_id ||
+   oldPayload.sub !== newPayload.sub;
} catch (e) {
- console.warn('Failed to parse token for session comparison, falling back to string comparison:', e);
- return oldToken !== newToken;
+ console.warn('Failed to parse token for session comparison:', e);
+ // On parse error, don't reconnect (assume same session)
+ return false;
}
};
- // Set up periodic ping and intelligent token refresh detection
+ // Set up periodic ping and check for auth session changes
const heartbeatInterval = setInterval(async () => {
if (ws?.readyState === WebSocket.OPEN && !isManuallyDisconnected) {
try {
// Check token validity (this may refresh if needed)
const currentToken = await apiClient.ensureValidToken();
- // Enhanced token change detection with detailed logging
- const tokenStringChanged = currentToken !== effectiveToken;
- const tokenSessionChanged = currentToken && effectiveToken ?
-   isTokenSessionDifferent(effectiveToken, currentToken) : tokenStringChanged;
- console.log('🔍 WebSocket token validation check:', {
-   hasCurrentToken: !!currentToken,
-   hasEffectiveToken: !!effectiveToken,
-   tokenStringChanged,
-   tokenSessionChanged,
-   currentTokenPreview: currentToken ? `${currentToken.slice(0, 20)}...${currentToken.slice(-10)}` : 'null',
-   effectiveTokenPreview: effectiveToken ? `${effectiveToken.slice(0, 20)}...${effectiveToken.slice(-10)}` : 'null'
- });
- // Only reconnect if we have a genuine session change (different iat)
- if (tokenSessionChanged) {
-   console.log('🔄 Token session changed - reconnecting WebSocket with new session token');
-   console.log('📊 Session change details:', {
-     reason: !currentToken ? 'token removed' :
-       !effectiveToken ? 'token added' : 'new auth session',
-     oldTokenIat: effectiveToken ? (() => {
-       try { return JSON.parse(atob(effectiveToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-     })() : 'N/A',
-     newTokenIat: currentToken ? (() => {
-       try { return JSON.parse(atob(currentToken.split('.')[1])).iat; } catch { return 'parse-error'; }
-     })() : 'N/A'
-   });
-   // Close current connection and trigger reconnection with new token
-   ws?.close(1000, 'Token session changed - reconnecting');
+ // Only reconnect if user changed (new auth session)
+ if (currentToken && effectiveToken && isNewAuthSession(effectiveToken, currentToken)) {
+   console.log('🔄 Auth session changed (different user) - reconnecting WebSocket');
+   ws?.close(1000, 'Auth session changed - reconnecting');
clearInterval(heartbeatInterval);
return;
- } else if (tokenStringChanged) {
-   console.log(' Token string changed but same session - continuing with current connection');
-   // Update effective token reference for future comparisons
+ }
+ // Token may have been refreshed but it's the same user - continue
+ if (currentToken && currentToken !== effectiveToken) {
+   console.log(' Token refreshed (same user) - updating reference');
effectiveToken = currentToken;
}
- console.log('✅ Token validated during heartbeat - same session');
// Send ping
ws?.send('ping');
- console.log('💓 Sent ping to server (token session validated)');
+ console.log('💓 Sent ping to server');
} catch (e) {
console.warn('Failed to send ping or validate token:', e);
clearInterval(heartbeatInterval);
@@ -420,7 +428,7 @@ export const useTrainingWebSocket = (
} else {
clearInterval(heartbeatInterval);
}
- }, 30000); // Check every 30 seconds for token refresh and send ping
+ }, WEBSOCKET_HEARTBEAT_INTERVAL_MS); // Check for auth changes and send ping
// Store interval for cleanup
(ws as any).heartbeatInterval = heartbeatInterval;
@@ -449,7 +457,8 @@ export const useTrainingWebSocket = (
if (initialData.type === 'product_completed') {
const productsCompleted = initialEventData.products_completed || 0;
const totalProducts = initialEventData.total_products || 1;
- initialProgress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+ const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+ initialProgress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
console.log('📦 Product training completed in initial state',
`${productsCompleted}/${totalProducts}`,
`progress: ${initialProgress}%`);
@@ -486,8 +495,9 @@ export const useTrainingWebSocket = (
const productsCompleted = eventData.products_completed || 0;
const totalProducts = eventData.total_products || 1;
- // Calculate progress: 20% base + (completed/total * 60%)
- progress = 20 + Math.floor((productsCompleted / totalProducts) * 60);
+ // Calculate progress: DATA_ANALYSIS% base + (completed/total * (TRAINING_RANGE_END - DATA_ANALYSIS)%)
+ const trainingRangeWidth = PROGRESS_TRAINING_RANGE_END - PROGRESS_DATA_ANALYSIS;
+ progress = PROGRESS_DATA_ANALYSIS + Math.floor((productsCompleted / totalProducts) * trainingRangeWidth);
console.log('📦 Product training completed',
`${productsCompleted}/${totalProducts}`,
@@ -585,8 +595,8 @@ export const useTrainingWebSocket = (
// Detailed logging for different close codes
switch (event.code) {
case 1000:
- if (event.reason === 'Token refreshed - reconnecting') {
-   console.log('🔄 WebSocket closed for token refresh - will reconnect immediately');
+ if (event.reason === 'Auth session changed - reconnecting') {
+   console.log('🔄 WebSocket closed for auth session change - will reconnect immediately');
} else {
console.log('🔒 WebSocket closed normally');
}
@@ -604,18 +614,21 @@ export const useTrainingWebSocket = (
console.log(`❓ WebSocket closed with code ${event.code}`);
}
- // Handle token refresh reconnection (immediate reconnect)
- if (event.code === 1000 && event.reason === 'Token refreshed - reconnecting') {
-   console.log('🔄 Reconnecting immediately due to token refresh...');
+ // Handle auth session change reconnection (immediate reconnect)
+ if (event.code === 1000 && event.reason === 'Auth session changed - reconnecting') {
+   console.log('🔄 Reconnecting immediately due to auth session change...');
reconnectTimer = setTimeout(() => {
- connect(); // Reconnect immediately with fresh token
- }, 1000); // Short delay to allow cleanup
+ connect(); // Reconnect immediately with new session token
+ }, WEBSOCKET_RECONNECT_INITIAL_DELAY_MS); // Short delay to allow cleanup
return;
}
// Try to reconnect if not manually disconnected and haven't exceeded max attempts
if (!isManuallyDisconnected && event.code !== 1000 && reconnectAttempts < maxReconnectAttempts) {
- const delay = Math.min(1000 * Math.pow(2, reconnectAttempts), 10000); // Exponential backoff, max 10s
+ const delay = Math.min(
+   WEBSOCKET_RECONNECT_INITIAL_DELAY_MS * Math.pow(2, reconnectAttempts),
+   WEBSOCKET_RECONNECT_MAX_DELAY_MS
+ ); // Exponential backoff
console.log(`🔄 Attempting to reconnect WebSocket in ${delay/1000}s... (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})`);
reconnectTimer = setTimeout(() => {

File: MLTrainingStep.tsx

@@ -5,6 +5,11 @@ import { Button } from '../../../ui/Button';
import { useCurrentTenant } from '../../../../stores/tenant.store';
import { useCreateTrainingJob, useTrainingWebSocket, useTrainingJobStatus } from '../../../../api/hooks/training';
import { Info } from 'lucide-react';
+ import {
+   TRAINING_SKIP_OPTION_DELAY_MS,
+   TRAINING_COMPLETION_DELAY_MS,
+   SKIP_TIMER_CHECK_INTERVAL_MS
+ } from '../../../../constants/training';
interface MLTrainingStepProps {
onNext: () => void;
@@ -38,16 +43,16 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
const currentTenant = useCurrentTenant();
const createTrainingJob = useCreateTrainingJob();
- // Check if training has been running for more than 2 minutes
+ // Check if training has been running for more than the skip delay threshold
useEffect(() => {
if (trainingStartTime && isTraining && !showSkipOption) {
const checkTimer = setInterval(() => {
const elapsedTime = (Date.now() - trainingStartTime) / 1000; // in seconds
- if (elapsedTime > 120) { // 2 minutes
+ if (elapsedTime > TRAINING_SKIP_OPTION_DELAY_MS / 1000) {
setShowSkipOption(true);
clearInterval(checkTimer);
}
- }, 5000); // Check every 5 seconds
+ }, SKIP_TIMER_CHECK_INTERVAL_MS);
return () => clearInterval(checkTimer);
}
@@ -79,7 +84,7 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
success: true,
message: 'Modelo entrenado correctamente'
});
- }, 2000);
+ }, TRAINING_COMPLETION_DELAY_MS);
}, [onComplete, jobId]);
const handleError = useCallback((data: any) => {
@@ -147,7 +152,7 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
message: 'Modelo entrenado correctamente',
detectedViaPolling: true
});
- }, 2000);
+ }, TRAINING_COMPLETION_DELAY_MS);
} else if (jobStatus.status === 'failed') {
console.log(`❌ Training failure detected (source: ${isConnected ? 'WebSocket' : 'HTTP polling'})`);
setError('Error detectado durante el entrenamiento (verificación de estado)');
@@ -169,13 +174,15 @@ export const MLTrainingStep: React.FC<MLTrainingStepProps> = ({
}
}, [jobStatus, jobId, trainingProgress?.stage, onComplete, isConnected]);
- // Auto-trigger training when component mounts
+ // Auto-trigger training when component mounts (run once)
+ const hasAutoStarted = React.useRef(false);
useEffect(() => {
- if (currentTenant?.id && !isTraining && !trainingProgress && !error) {
+ if (currentTenant?.id && !hasAutoStarted.current && !isTraining && !trainingProgress && !error) {
console.log('🚀 Auto-starting ML training for tenant:', currentTenant.id);
+ hasAutoStarted.current = true;
handleStartTraining();
}
- }, [currentTenant?.id]); // Only run when tenant is available
+ }, [currentTenant?.id, isTraining, trainingProgress, error]); // Include all checked dependencies
const handleStartTraining = async () => {
if (!currentTenant?.id) {

File: frontend/src/constants/training.ts (new)

@@ -0,0 +1,25 @@
/**
* Training Progress Constants
* Centralized constants for training UI and behavior
*/
// Time Intervals (milliseconds)
export const TRAINING_SKIP_OPTION_DELAY_MS = 120000; // 2 minutes
export const TRAINING_COMPLETION_DELAY_MS = 2000; // 2 seconds
export const HTTP_POLLING_INTERVAL_MS = 5000; // 5 seconds
export const HTTP_POLLING_DEBOUNCE_MS = 5000; // 5 seconds
export const WEBSOCKET_HEARTBEAT_INTERVAL_MS = 30000; // 30 seconds
// WebSocket Configuration
export const WEBSOCKET_MAX_RECONNECT_ATTEMPTS = 3;
export const WEBSOCKET_RECONNECT_INITIAL_DELAY_MS = 1000; // 1 second
export const WEBSOCKET_RECONNECT_MAX_DELAY_MS = 10000; // 10 seconds
// Progress Milestones
export const PROGRESS_DATA_ANALYSIS = 20;
export const PROGRESS_TRAINING_RANGE_START = 20;
export const PROGRESS_TRAINING_RANGE_END = 80;
export const PROGRESS_COMPLETED = 100;
// Skip Timer Check Interval
export const SKIP_TIMER_CHECK_INTERVAL_MS = 5000; // 5 seconds

File: training_operations.py

@@ -189,15 +189,8 @@ async def start_training_job(
# Calculate estimated completion time
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
- # Publish training.started event immediately so WebSocket clients
- # have initial state when they connect
- await publish_training_started(
-     job_id=job_id,
-     tenant_id=tenant_id,
-     total_products=0,  # Will be updated when actual training starts
-     estimated_duration_minutes=estimated_duration_minutes,
-     estimated_completion_time=estimated_completion_time.isoformat()
- )
+ # Note: training.started event will be published by the trainer with accurate product count
+ # We don't publish here to avoid duplicate events
# Add enhanced background task
background_tasks.add_task(
@@ -401,11 +394,6 @@ async def execute_training_job_background(
# Failure event is published by the training service
await publish_training_failed(job_id, tenant_id, str(training_error))
- except Exception as background_error:
-     logger.error("Critical error in enhanced background training job",
-                  job_id=job_id,
-                  error=str(background_error))
finally:
logger.info("Enhanced background training job cleanup completed",
job_id=job_id)

File: services/training/app/core/training_constants.py (new)

@@ -0,0 +1,35 @@
"""
Training Progress Constants
Centralized constants for training progress tracking and timing
"""
# Progress Milestones (percentage)
PROGRESS_STARTED = 0
PROGRESS_DATA_VALIDATION = 10
PROGRESS_DATA_ANALYSIS = 20
PROGRESS_DATA_PREPARATION_COMPLETE = 30
PROGRESS_ML_TRAINING_START = 40
PROGRESS_TRAINING_COMPLETE = 85
PROGRESS_STORING_MODELS = 92
PROGRESS_STORING_METRICS = 94
PROGRESS_COMPLETED = 100
# Progress Ranges
PROGRESS_TRAINING_RANGE_START = 20 # After data analysis
PROGRESS_TRAINING_RANGE_END = 80 # Before finalization
PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START # 60%
# Time Limits and Intervals (seconds)
MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800 # 30 minutes
WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
# Training Timeouts (seconds)
TRAINING_SKIP_OPTION_DELAY_SECONDS = 120 # 2 minutes
HTTP_POLLING_INTERVAL_MS = 5000 # 5 seconds
HTTP_POLLING_DEBOUNCE_MS = 5000 # 5 seconds before enabling after WebSocket disconnect
# Frontend Display
TRAINING_COMPLETION_DELAY_MS = 2000 # Delay before navigating after completion

File: trainer.py

@@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
- from datetime import datetime
+ from datetime import datetime, timezone
import structlog
import uuid
import time
@@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer:
# Event 2: Data Analysis (20%)
# Recalculate time remaining based on elapsed time
- elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
+ start_time = await repos['training_log'].get_start_time(job_id)
+ elapsed_seconds = 0
+ if start_time:
+     elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
# Estimate remaining time: we've done ~20% of work (data analysis)
# Remaining 80% includes training all products
@@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer:
except Exception as e:
logger.error("Enhanced ML training pipeline failed",
job_id=job_id,
- error=str(e))
+ error=str(e),
+ exc_info=True)
# Publish training failed event
await publish_training_failed(job_id, tenant_id, str(e))
@@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer:
logger.error("Single product model training failed",
job_id=job_id,
inventory_product_id=inventory_product_id,
- error=str(e))
+ error=str(e),
+ exc_info=True)
raise
def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:

File: training_log_repository.py

@@ -330,3 +330,16 @@ class TrainingLogRepository(TrainingBaseRepository):
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
+     async def get_start_time(self, job_id: str) -> Optional[datetime]:
+         """Get the start time for a training job"""
+         try:
+             log_entry = await self.get_by_job_id(job_id)
+             if log_entry and log_entry.start_time:
+                 return log_entry.start_time
+             return None
+         except Exception as e:
+             logger.error("Failed to get start time",
+                          job_id=job_id,
+                          error=str(e))
+             return None

File: progress_tracker.py

@@ -10,6 +10,11 @@ from datetime import datetime, timezone
from app.services.training_events import publish_product_training_completed
from app.utils.time_estimation import calculate_estimated_completion_time
+ from app.core.training_constants import (
+     PROGRESS_TRAINING_RANGE_START,
+     PROGRESS_TRAINING_RANGE_END,
+     PROGRESS_TRAINING_RANGE_WIDTH
+ )
logger = structlog.get_logger()
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
self.start_time = datetime.now(timezone.utc)
# Calculate progress increment per product
- # 60% of total progress (from 20% to 80%) divided by number of products
- self.progress_per_product = 60 / total_products if total_products > 0 else 0
+ # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
+ self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0
logger.info("ParallelProductProgressTracker initialized",
job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
estimated_completion_time=estimated_completion_time
)
- # Calculate overall progress (20% base + progress from completed products)
+ # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
# This calculation is done on the frontend/consumer side based on the event data
- overall_progress = 20 + int((current_progress / self.total_products) * 60)
+ overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
logger.info("Product training completed",
job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
return {
"products_completed": self.products_completed,
"total_products": self.total_products,
"progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
"progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
}

File: training_events.py

@@ -135,6 +135,61 @@ async def publish_data_analysis(
return success
+ async def publish_training_progress(
+     job_id: str,
+     tenant_id: str,
+     progress: int,
+     current_step: str,
+     step_details: Optional[str] = None,
+     estimated_time_remaining_seconds: Optional[int] = None,
+     estimated_completion_time: Optional[str] = None
+ ) -> bool:
+     """
+     Generic Training Progress Event (for any progress percentage)
+
+     Args:
+         job_id: Training job identifier
+         tenant_id: Tenant identifier
+         progress: Progress percentage (0-100)
+         current_step: Current step name
+         step_details: Details about the current step
+         estimated_time_remaining_seconds: Estimated time remaining in seconds
+         estimated_completion_time: ISO timestamp of estimated completion
+     """
+     event_data = {
+         "service_name": "training-service",
+         "event_type": "training.progress",
+         "timestamp": datetime.now().isoformat(),
+         "data": {
+             "job_id": job_id,
+             "tenant_id": tenant_id,
+             "progress": progress,
+             "current_step": current_step,
+             "step_details": step_details or current_step,
+             "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+             "estimated_completion_time": estimated_completion_time
+         }
+     }
+     success = await training_publisher.publish_event(
+         exchange_name="training.events",
+         routing_key="training.progress",
+         event_data=event_data
+     )
+     if success:
+         logger.info("Published training progress event",
+                     job_id=job_id,
+                     progress=progress,
+                     current_step=current_step)
+     else:
+         logger.error("Failed to publish training progress event",
+                      job_id=job_id,
+                      progress=progress)
+     return success
async def publish_product_training_completed(
job_id: str,
tenant_id: str,

File: training_service.py

@@ -16,6 +16,15 @@ import pandas as pd
from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
+ from app.core.training_constants import (
+     PROGRESS_DATA_VALIDATION,
+     PROGRESS_DATA_PREPARATION_COMPLETE,
+     PROGRESS_ML_TRAINING_START,
+     PROGRESS_TRAINING_COMPLETE,
+     PROGRESS_STORING_MODELS,
+     PROGRESS_STORING_METRICS,
+     MAX_ESTIMATED_TIME_REMAINING_SECONDS
+ )
# Import repositories
from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
- job_id, 10, "data_validation", "running"
+ job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
)
# Orchestrator now handles sales data validation to eliminate duplicate fetching
@@ -204,13 +213,13 @@ class EnhancedTrainingService:
tenant_id=tenant_id, job_id=job_id)
await self.training_log_repo.update_log_progress(
- job_id, 30, "data_preparation_complete", "running"
+ job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
)
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
await self.training_log_repo.update_log_progress(
- job_id, 40, "ml_training", "running"
+ job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
)
training_results = await self.trainer.train_tenant_models(
@@ -220,7 +229,17 @@ class EnhancedTrainingService:
)
await self.training_log_repo.update_log_progress(
- job_id, 85, "training_complete", "running"
+ job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
)
+ # Publish progress event (85%)
+ from app.services.training_events import publish_training_progress
+ await publish_training_progress(
+     job_id=job_id,
+     tenant_id=tenant_id,
+     progress=PROGRESS_TRAINING_COMPLETE,
+     current_step="Training Complete",
+     step_details="All products trained successfully"
+ )
# Step 3: Store model records using repository
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
)
await self.training_log_repo.update_log_progress(
- job_id, 92, "storing_models", "running"
+ job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
)
+ # Publish progress event (92%)
+ await publish_training_progress(
+     job_id=job_id,
+     tenant_id=tenant_id,
+     progress=PROGRESS_STORING_MODELS,
+     current_step="Storing Models",
+     step_details=f"Saved {len(stored_models)} trained models to database"
+ )
# Step 4: Create performance metrics
await self.training_log_repo.update_log_progress(
- job_id, 94, "storing_performance_metrics", "running"
+ job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
)
+ # Publish progress event (94%)
+ await publish_training_progress(
+     job_id=job_id,
+     tenant_id=tenant_id,
+     progress=PROGRESS_STORING_METRICS,
+     current_step="Storing Performance Metrics",
+     step_details="Saving model performance metrics"
+ )
await self._create_performance_metrics(
@@ -316,7 +353,8 @@ class EnhancedTrainingService:
except Exception as e:
logger.error("Enhanced training job failed",
job_id=job_id,
- error=str(e))
+ error=str(e),
+ exc_info=True)
# Mark as failed in database
await self.training_log_repo.complete_training_log(
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
- if elapsed_time > 0:
+ if elapsed_time > 0 and log.progress > 0:  # Double-check progress is positive
# Calculate estimated total time based on progress
- estimated_total_time = (elapsed_time / log.progress) * 100
+ # Use max(log.progress, 1) as additional safety against division by zero
+ estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
- # Cap at reasonable maximum (e.g., 30 minutes)
- estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
+ # Cap at reasonable maximum (30 minutes)
+ estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
# Extract products info from results if available
products_total = 0