Fix multiple critical bugs in onboarding training step
This commit addresses all identified bugs and issues in the training code path.

## Critical Fixes:
- Add get_start_time() method to TrainingLogRepository and fix non-existent method call
- Remove duplicate training.started event from API endpoint (trainer publishes the accurate one)
- Add missing progress events for the 80-100% range (85%, 92%, 94%) to eliminate the progress "dead zone" (milestone flow sketched below)

## High Priority Fixes:
- Fix division-by-zero risk in time estimation with a double-check and max() safety
- Remove unreachable exception handler in training_operations.py
- Simplify WebSocket token refresh logic to only reconnect on actual user session changes

## Medium Priority Fixes:
- Fix auto-start training effect with useRef to prevent duplicate starts
- Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket
- Extract all magic numbers to centralized constants files:
  - Backend: services/training/app/core/training_constants.py
  - Frontend: frontend/src/constants/training.ts
- Standardize error logging with exc_info=True on critical errors

## Code Quality Improvements:
- All progress percentages now use named constants
- All timeouts and intervals now use named constants
- Improved code maintainability and readability
- Better separation of concerns

## Files Changed:
- Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py
- Backend: training_operations.py, training_log_repository.py, training_constants.py (new)
- Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new)

All training progress events now properly flow from 0% to 100% with no gaps.
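As context for the dead-zone fix, a minimal sketch (values mirroring the new constants, not the actual service code) of how the fixed milestones and the per-product 20-80% band are meant to compose into a gap-free sequence:

```python
# Sketch only: how the milestones and the per-product band compose
# into a gap-free 0-100% progress sequence.
MILESTONES = (0, 10, 20, 30, 40, 85, 92, 94, 100)

def overall_progress(products_done: int, total_products: int) -> int:
    """Overall progress while product models train (the 20-80% band)."""
    if total_products <= 0:
        return 20
    return 20 + int((products_done / total_products) * 60)

# With 10 products, each completion advances overall progress by 6%:
assert [overall_progress(n, 10) for n in (0, 5, 10)] == [20, 50, 80]
```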
@@ -189,15 +189,8 @@ async def start_training_job(
-    # Calculate estimated completion time
-    estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
-
-    # Publish training.started event immediately so WebSocket clients
-    # have initial state when they connect
-    await publish_training_started(
-        job_id=job_id,
-        tenant_id=tenant_id,
-        total_products=0,  # Will be updated when actual training starts
-        estimated_duration_minutes=estimated_duration_minutes,
-        estimated_completion_time=estimated_completion_time.isoformat()
-    )
+    # Note: training.started event will be published by the trainer with accurate product count
+    # We don't publish here to avoid duplicate events

     # Add enhanced background task
     background_tasks.add_task(
@@ -401,11 +394,6 @@ async def execute_training_job_background(
         # Failure event is published by the training service
         await publish_training_failed(job_id, tenant_id, str(training_error))
-
-    except Exception as background_error:
-        logger.error("Critical error in enhanced background training job",
-                     job_id=job_id,
-                     error=str(background_error))

     finally:
         logger.info("Enhanced background training job cleanup completed",
                     job_id=job_id)
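Why the removed handler was unreachable: the preceding except clause already catches Exception, so a second `except Exception` on the same try can never match. A simplified, self-contained illustration (all names here are stand-ins, not the actual module):

```python
import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

async def run_training(job_id: str) -> None:  # hypothetical stand-in
    raise RuntimeError("boom")

async def publish_training_failed(job_id: str, tenant_id: str, error: str) -> None:
    logger.info("training.failed published: %s", error)  # stand-in publisher

async def background_job(job_id: str, tenant_id: str) -> None:
    try:
        await run_training(job_id)
    except Exception as training_error:
        # This clause catches every Exception...
        await publish_training_failed(job_id, tenant_id, str(training_error))
    except Exception as background_error:
        # ...so this second clause can never match: dead code, now removed.
        logger.error("Critical error", exc_info=background_error)
    finally:
        logger.info("cleanup completed for %s", job_id)

asyncio.run(background_job("job-1", "tenant-1"))
```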
services/training/app/core/training_constants.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+"""
+Training Progress Constants
+Centralized constants for training progress tracking and timing
+"""
+
+# Progress Milestones (percentage)
+PROGRESS_STARTED = 0
+PROGRESS_DATA_VALIDATION = 10
+PROGRESS_DATA_ANALYSIS = 20
+PROGRESS_DATA_PREPARATION_COMPLETE = 30
+PROGRESS_ML_TRAINING_START = 40
+PROGRESS_TRAINING_COMPLETE = 85
+PROGRESS_STORING_MODELS = 92
+PROGRESS_STORING_METRICS = 94
+PROGRESS_COMPLETED = 100
+
+# Progress Ranges
+PROGRESS_TRAINING_RANGE_START = 20  # After data analysis
+PROGRESS_TRAINING_RANGE_END = 80  # Before finalization
+PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START  # 60%
+
+# Time Limits and Intervals (seconds)
+MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800  # 30 minutes
+WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
+WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
+WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
+WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
+
+# Training Timeouts (seconds)
+TRAINING_SKIP_OPTION_DELAY_SECONDS = 120  # 2 minutes
+HTTP_POLLING_INTERVAL_MS = 5000  # 5 seconds
+HTTP_POLLING_DEBOUNCE_MS = 5000  # 5 seconds before enabling after WebSocket disconnect
+
+# Frontend Display
+TRAINING_COMPLETION_DELAY_MS = 2000  # Delay before navigating after completion
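As a quick consistency check on these values, the milestones are strictly increasing (so the progress bar never moves backwards) and the range arithmetic agrees with itself; a throwaway snippet one might run in the service's environment:

```python
# Illustrative sanity check over the new constants module
from app.core.training_constants import (
    PROGRESS_STARTED, PROGRESS_DATA_VALIDATION, PROGRESS_DATA_ANALYSIS,
    PROGRESS_DATA_PREPARATION_COMPLETE, PROGRESS_ML_TRAINING_START,
    PROGRESS_TRAINING_COMPLETE, PROGRESS_STORING_MODELS,
    PROGRESS_STORING_METRICS, PROGRESS_COMPLETED,
    PROGRESS_TRAINING_RANGE_START, PROGRESS_TRAINING_RANGE_END,
    PROGRESS_TRAINING_RANGE_WIDTH,
)

milestones = [
    PROGRESS_STARTED, PROGRESS_DATA_VALIDATION, PROGRESS_DATA_ANALYSIS,
    PROGRESS_DATA_PREPARATION_COMPLETE, PROGRESS_ML_TRAINING_START,
    PROGRESS_TRAINING_COMPLETE, PROGRESS_STORING_MODELS,
    PROGRESS_STORING_METRICS, PROGRESS_COMPLETED,
]
assert milestones == sorted(set(milestones))  # strictly increasing, no gaps backwards
assert PROGRESS_TRAINING_RANGE_WIDTH == PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START
```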
@@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend
 from typing import Dict, List, Any, Optional
 import pandas as pd
 import numpy as np
-from datetime import datetime
+from datetime import datetime, timezone
 import structlog
 import uuid
 import time
@@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer:

             # Event 2: Data Analysis (20%)
             # Recalculate time remaining based on elapsed time
-            elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
+            start_time = await repos['training_log'].get_start_time(job_id)
+            elapsed_seconds = 0
+            if start_time:
+                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())

             # Estimate remaining time: we've done ~20% of work (data analysis)
             # Remaining 80% includes training all products
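The estimation code that follows these comments sits outside the hunk; as a rough illustration of the linear scaling the comments imply (an assumption, not the file's actual code):

```python
# Assumed linear extrapolation at the 20% (data analysis) mark:
# if 20% of the work took elapsed_seconds, the remaining 80% should
# take about four times as long.
elapsed_seconds = 45
estimated_remaining_seconds = int(elapsed_seconds * (100 - 20) / 20)
print(estimated_remaining_seconds)  # 180
```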
@@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer:
         except Exception as e:
             logger.error("Enhanced ML training pipeline failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)

             # Publish training failed event
             await publish_training_failed(job_id, tenant_id, str(e))

@@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer:
             logger.error("Single product model training failed",
                          job_id=job_id,
                          inventory_product_id=inventory_product_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
             raise

     def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
@@ -329,4 +329,17 @@ class TrainingLogRepository(TrainingBaseRepository):
                 "min_duration_minutes": 0.0,
                 "max_duration_minutes": 0.0,
                 "completed_jobs_with_duration": 0
             }
         }
+
+    async def get_start_time(self, job_id: str) -> Optional[datetime]:
+        """Get the start time for a training job"""
+        try:
+            log_entry = await self.get_by_job_id(job_id)
+            if log_entry and log_entry.start_time:
+                return log_entry.start_time
+            return None
+        except Exception as e:
+            logger.error("Failed to get start time",
+                         job_id=job_id,
+                         error=str(e))
+            return None
@@ -10,6 +10,11 @@ from datetime import datetime, timezone

 from app.services.training_events import publish_product_training_completed
 from app.utils.time_estimation import calculate_estimated_completion_time
+from app.core.training_constants import (
+    PROGRESS_TRAINING_RANGE_START,
+    PROGRESS_TRAINING_RANGE_END,
+    PROGRESS_TRAINING_RANGE_WIDTH
+)

 logger = structlog.get_logger()
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
         self.start_time = datetime.now(timezone.utc)

         # Calculate progress increment per product
-        # 60% of total progress (from 20% to 80%) divided by number of products
-        self.progress_per_product = 60 / total_products if total_products > 0 else 0
+        # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
+        self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0

         logger.info("ParallelProductProgressTracker initialized",
                     job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
             estimated_completion_time=estimated_completion_time
         )

-        # Calculate overall progress (20% base + progress from completed products)
+        # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
         # This calculation is done on the frontend/consumer side based on the event data
-        overall_progress = 20 + int((current_progress / self.total_products) * 60)
+        overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)

         logger.info("Product training completed",
                     job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
         return {
             "products_completed": self.products_completed,
             "total_products": self.total_products,
-            "progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
+            "progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
         }
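Concretely, the tracker maps product completions onto the 20-80% band; a worked example with a dozen products:

```python
# Worked example of the tracker arithmetic
# (PROGRESS_TRAINING_RANGE_START = 20, PROGRESS_TRAINING_RANGE_WIDTH = 60)
total_products = 12
progress_per_product = 60 / total_products        # 5.0% of overall progress each

for done in (0, 6, 12):
    overall = 20 + int((done / total_products) * 60)
    print(done, overall)                          # 0 -> 20, 6 -> 50, 12 -> 80
```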
@@ -135,6 +135,61 @@ async def publish_data_analysis(
     return success


+async def publish_training_progress(
+    job_id: str,
+    tenant_id: str,
+    progress: int,
+    current_step: str,
+    step_details: Optional[str] = None,
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
+) -> bool:
+    """
+    Generic Training Progress Event (for any progress percentage)
+
+    Args:
+        job_id: Training job identifier
+        tenant_id: Tenant identifier
+        progress: Progress percentage (0-100)
+        current_step: Current step name
+        step_details: Details about the current step
+        estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
+    """
+    event_data = {
+        "service_name": "training-service",
+        "event_type": "training.progress",
+        "timestamp": datetime.now().isoformat(),
+        "data": {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "progress": progress,
+            "current_step": current_step,
+            "step_details": step_details or current_step,
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
+        }
+    }
+
+    success = await training_publisher.publish_event(
+        exchange_name="training.events",
+        routing_key="training.progress",
+        event_data=event_data
+    )
+
+    if success:
+        logger.info("Published training progress event",
+                    job_id=job_id,
+                    progress=progress,
+                    current_step=current_step)
+    else:
+        logger.error("Failed to publish training progress event",
+                     job_id=job_id,
+                     progress=progress)
+
+    return success
+
+
 async def publish_product_training_completed(
     job_id: str,
     tenant_id: str,
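For reference, a consumer bound to the training.events exchange with routing key training.progress receives a payload shaped like the following (field values are illustrative):

```python
# Illustrative training.progress payload, mirroring the publisher above
event = {
    "service_name": "training-service",
    "event_type": "training.progress",
    "timestamp": "2024-01-01T12:00:00",          # naive ISO timestamp, as published
    "data": {
        "job_id": "job-123",                      # illustrative values
        "tenant_id": "tenant-1",
        "progress": 92,
        "current_step": "Storing Models",
        "step_details": "Saved 12 trained models to database",
        "estimated_time_remaining_seconds": 30,
        "estimated_completion_time": "2024-01-01T12:00:30",
    },
}
```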
@@ -16,6 +16,15 @@ import pandas as pd
 from app.ml.trainer import EnhancedBakeryMLTrainer
 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
 from app.services.training_orchestrator import TrainingDataOrchestrator
+from app.core.training_constants import (
+    PROGRESS_DATA_VALIDATION,
+    PROGRESS_DATA_PREPARATION_COMPLETE,
+    PROGRESS_ML_TRAINING_START,
+    PROGRESS_TRAINING_COMPLETE,
+    PROGRESS_STORING_MODELS,
+    PROGRESS_STORING_METRICS,
+    MAX_ESTIMATED_TIME_REMAINING_SECONDS
+)

 # Import repositories
 from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
         # Step 1: Prepare training dataset (includes sales data validation)
         logger.info("Step 1: Preparing and aligning training data (with validation)")
         await self.training_log_repo.update_log_progress(
-            job_id, 10, "data_validation", "running"
+            job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
         )

         # Orchestrator now handles sales data validation to eliminate duplicate fetching

@@ -204,13 +213,13 @@ class EnhancedTrainingService:
                     tenant_id=tenant_id, job_id=job_id)

         await self.training_log_repo.update_log_progress(
-            job_id, 30, "data_preparation_complete", "running"
+            job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
         )

         # Step 2: Execute ML training pipeline
         logger.info("Step 2: Starting ML training pipeline")
         await self.training_log_repo.update_log_progress(
-            job_id, 40, "ml_training", "running"
+            job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
         )

         training_results = await self.trainer.train_tenant_models(
@@ -220,9 +229,19 @@ class EnhancedTrainingService:
         )

         await self.training_log_repo.update_log_progress(
-            job_id, 85, "training_complete", "running"
+            job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
         )

+        # Publish progress event (85%)
+        from app.services.training_events import publish_training_progress
+        await publish_training_progress(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            progress=PROGRESS_TRAINING_COMPLETE,
+            current_step="Training Complete",
+            step_details="All products trained successfully"
+        )
+
         # Step 3: Store model records using repository
         logger.info("Step 3: Storing model records")
         logger.debug("Training results structure",
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
         )

         await self.training_log_repo.update_log_progress(
-            job_id, 92, "storing_models", "running"
+            job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
         )

+        # Publish progress event (92%)
+        await publish_training_progress(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            progress=PROGRESS_STORING_MODELS,
+            current_step="Storing Models",
+            step_details=f"Saved {len(stored_models)} trained models to database"
+        )
+
         # Step 4: Create performance metrics
         await self.training_log_repo.update_log_progress(
-            job_id, 94, "storing_performance_metrics", "running"
+            job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
         )

+        # Publish progress event (94%)
+        await publish_training_progress(
+            job_id=job_id,
+            tenant_id=tenant_id,
+            progress=PROGRESS_STORING_METRICS,
+            current_step="Storing Performance Metrics",
+            step_details="Saving model performance metrics"
+        )
+
         await self._create_performance_metrics(
@@ -316,8 +353,9 @@ class EnhancedTrainingService:
         except Exception as e:
             logger.error("Enhanced training job failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)

             # Mark as failed in database
             await self.training_log_repo.complete_training_log(
                 job_id, error_message=str(e)
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
         if log.status == "running" and log.progress > 0 and log.start_time:
             from datetime import datetime, timezone
             elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
-            if elapsed_time > 0:
+            if elapsed_time > 0 and log.progress > 0:  # Double-check progress is positive
                 # Calculate estimated total time based on progress
-                estimated_total_time = (elapsed_time / log.progress) * 100
+                # Use max(log.progress, 1) as additional safety against division by zero
+                estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
                 estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
-                # Cap at reasonable maximum (e.g., 30 minutes)
-                estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
+                # Cap at reasonable maximum (30 minutes)
+                estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))

         # Extract products info from results if available
         products_total = 0
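To see the hardened estimate in action: with two minutes elapsed at 40% progress, the linear extrapolation projects five minutes total, leaving three; the max()/min() clamps keep the result within 0 to 30 minutes even for degenerate inputs:

```python
# Worked example of the guarded time estimate
elapsed_time = 120.0   # seconds elapsed so far
progress = 40          # percent complete

estimated_total_time = (elapsed_time / max(progress, 1)) * 100   # 300.0 s
remaining = int(estimated_total_time - elapsed_time)             # 180 s
remaining = max(0, min(remaining, 1800))                         # clamp to 0..30 min
print(remaining)  # 180
```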