Fix multiple critical bugs in onboarding training step

This commit addresses all identified bugs and issues in the training code path:

## Critical Fixes:
- Add get_start_time() method to TrainingLogRepository and fix a call to the non-existent _get_start_time()
- Remove duplicate training.started event from API endpoint (trainer publishes the accurate one)
- Add missing progress events for 80-100% range (85%, 92%, 94%) to eliminate progress "dead zone"
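
The new milestone values are centralized in training_constants.py (full file in the diff below) and mirrored on the frontend. A minimal TypeScript sketch of that mirror, assuming the frontend constant names match the Python ones:

```typescript
// Progress milestones published by the training service, mirroring
// services/training/app/core/training_constants.py (names assumed to
// match in frontend/src/constants/training.ts).
export const PROGRESS_STARTED = 0;
export const PROGRESS_DATA_VALIDATION = 10;
export const PROGRESS_DATA_ANALYSIS = 20;
export const PROGRESS_DATA_PREPARATION_COMPLETE = 30;
export const PROGRESS_ML_TRAINING_START = 40;
export const PROGRESS_TRAINING_COMPLETE = 85; // new: closes the 80-100% gap
export const PROGRESS_STORING_MODELS = 92;    // new
export const PROGRESS_STORING_METRICS = 94;   // new
export const PROGRESS_COMPLETED = 100;
```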

## High Priority Fixes:
- Fix division-by-zero risk in time estimation with a progress > 0 double-check and max() safety (worked example after this list)
- Remove unreachable exception handler in training_operations.py
- Simplify WebSocket token refresh logic to reconnect only on actual user session changes (sketch after this list)
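
A worked sketch of the fixed estimate (TypeScript; mirrors the Python in training_service.py below, with the cap from MAX_ESTIMATED_TIME_REMAINING_SECONDS):

```typescript
// Remaining-time estimate: guard the divisor, then clamp to [0, 1800] s.
function estimateRemainingSeconds(elapsedS: number, progressPct: number): number {
  const estimatedTotalS = (elapsedS / Math.max(progressPct, 1)) * 100;
  return Math.max(0, Math.min(Math.floor(estimatedTotalS - elapsedS), 1800));
}

// e.g. 120 s elapsed at 40% progress: total = 120 / 40 * 100 = 300 s,
// so about 180 s remain; anything above 1800 s is capped.
```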
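And a minimal sketch of the simplified reconnect rule, assuming a React hook keyed on a stable session id (hook and parameter names are hypothetical, not the actual code in the frontend training hooks):

```typescript
import { useEffect } from "react";

// Reopen the training WebSocket only when the user session itself changes;
// routine token refreshes no longer tear the socket down.
function useTrainingSocket(sessionId: string | null, connect: () => () => void) {
  useEffect(() => {
    if (!sessionId) return;       // no session, no socket
    const disconnect = connect(); // open a socket for this session
    return disconnect;            // closed only when sessionId changes/unmounts
  }, [sessionId]); // deliberately not keyed on the access token
}
```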

## Medium Priority Fixes:
- Fix auto-start training effect with useRef to prevent duplicate starts (see the useRef sketch after this list)
- Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket (see the debounce sketch after this list)
- Extract all magic numbers to centralized constants files:
  - Backend: services/training/app/core/training_constants.py
  - Frontend: frontend/src/constants/training.ts
- Standardize error logging with exc_info=True on critical errors
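
For the auto-start fix, a minimal sketch of the useRef guard (hook and callback names are hypothetical stand-ins for the logic in MLTrainingStep.tsx):

```typescript
import { useEffect, useRef } from "react";

// Kick off training at most once per mount, even if the effect re-runs
// (e.g. under React strict mode or changing dependencies).
function useAutoStartTraining(ready: boolean, startTraining: () => void) {
  const startedRef = useRef(false);
  useEffect(() => {
    if (ready && !startedRef.current) {
      startedRef.current = true; // flip before calling to block duplicates
      startTraining();
    }
  }, [ready, startTraining]);
}
```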
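And for the WebSocket/polling race, a sketch of the 5 s debounce (constants as in frontend/src/constants/training.ts; the handler shape is an assumption):

```typescript
export const HTTP_POLLING_DEBOUNCE_MS = 5000; // wait before falling back
export const HTTP_POLLING_INTERVAL_MS = 5000; // poll cadence once enabled

// Enable HTTP polling only if the WebSocket stays down for the whole
// debounce window, so a quick reconnect never races the poller.
function onSocketDisconnect(isSocketOpen: () => boolean, startPolling: () => void) {
  setTimeout(() => {
    if (!isSocketOpen()) startPolling();
  }, HTTP_POLLING_DEBOUNCE_MS);
}
```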

## Code Quality Improvements:
- All progress percentages now use named constants
- All timeouts and intervals now use named constants
- Improved code maintainability and readability
- Better separation of concerns

## Files Changed:
- Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py
- Backend: training_operations.py, training_log_repository.py, training_constants.py (new)
- Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new)

All training progress events now properly flow from 0% to 100% with no gaps.
Author: Claude
Date: 2025-11-05 13:02:39 +00:00
Parent: e3ea92640b
Commit: 5a84be83d6
10 changed files with 291 additions and 106 deletions

training_operations.py

```diff
@@ -189,15 +189,8 @@ async def start_training_job(
     # Calculate estimated completion time
     estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
 
-    # Publish training.started event immediately so WebSocket clients
-    # have initial state when they connect
-    await publish_training_started(
-        job_id=job_id,
-        tenant_id=tenant_id,
-        total_products=0,  # Will be updated when actual training starts
-        estimated_duration_minutes=estimated_duration_minutes,
-        estimated_completion_time=estimated_completion_time.isoformat()
-    )
+    # Note: training.started event will be published by the trainer with accurate product count
+    # We don't publish here to avoid duplicate events
 
     # Add enhanced background task
     background_tasks.add_task(
@@ -401,11 +394,6 @@ async def execute_training_job_background(
         # Failure event is published by the training service
         await publish_training_failed(job_id, tenant_id, str(training_error))
-    except Exception as background_error:
-        logger.error("Critical error in enhanced background training job",
-                     job_id=job_id,
-                     error=str(background_error))
     finally:
         logger.info("Enhanced background training job cleanup completed",
                     job_id=job_id)
```

services/training/app/core/training_constants.py (new file)

```diff
@@ -0,0 +1,35 @@
+"""
+Training Progress Constants
+Centralized constants for training progress tracking and timing
+"""
+
+# Progress Milestones (percentage)
+PROGRESS_STARTED = 0
+PROGRESS_DATA_VALIDATION = 10
+PROGRESS_DATA_ANALYSIS = 20
+PROGRESS_DATA_PREPARATION_COMPLETE = 30
+PROGRESS_ML_TRAINING_START = 40
+PROGRESS_TRAINING_COMPLETE = 85
+PROGRESS_STORING_MODELS = 92
+PROGRESS_STORING_METRICS = 94
+PROGRESS_COMPLETED = 100
+
+# Progress Ranges
+PROGRESS_TRAINING_RANGE_START = 20  # After data analysis
+PROGRESS_TRAINING_RANGE_END = 80  # Before finalization
+PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START  # 60%
+
+# Time Limits and Intervals (seconds)
+MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800  # 30 minutes
+WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
+WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
+WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
+WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10
+
+# Training Timeouts (seconds)
+TRAINING_SKIP_OPTION_DELAY_SECONDS = 120  # 2 minutes
+HTTP_POLLING_INTERVAL_MS = 5000  # 5 seconds
+HTTP_POLLING_DEBOUNCE_MS = 5000  # 5 seconds before enabling after WebSocket disconnect
+
+# Frontend Display
+TRAINING_COMPLETION_DELAY_MS = 2000  # Delay before navigating after completion
```

trainer.py

```diff
@@ -6,7 +6,7 @@ Main ML pipeline coordinator using repository pattern for data access and depend
 from typing import Dict, List, Any, Optional
 import pandas as pd
 import numpy as np
-from datetime import datetime
+from datetime import datetime, timezone
 import structlog
 import uuid
 import time
@@ -187,7 +187,10 @@ class EnhancedBakeryMLTrainer:
             # Event 2: Data Analysis (20%)
             # Recalculate time remaining based on elapsed time
-            elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
+            start_time = await repos['training_log'].get_start_time(job_id)
+            elapsed_seconds = 0
+            if start_time:
+                elapsed_seconds = int((datetime.now(timezone.utc) - start_time).total_seconds())
 
             # Estimate remaining time: we've done ~20% of work (data analysis)
             # Remaining 80% includes training all products
@@ -285,7 +288,8 @@ class EnhancedBakeryMLTrainer:
         except Exception as e:
             logger.error("Enhanced ML training pipeline failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
 
             # Publish training failed event
             await publish_training_failed(job_id, tenant_id, str(e))
@@ -397,7 +401,8 @@ class EnhancedBakeryMLTrainer:
             logger.error("Single product model training failed",
                          job_id=job_id,
                          inventory_product_id=inventory_product_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
             raise
 
     def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
```

training_log_repository.py

```diff
@@ -329,4 +329,17 @@ class TrainingLogRepository(TrainingBaseRepository):
             "min_duration_minutes": 0.0,
             "max_duration_minutes": 0.0,
             "completed_jobs_with_duration": 0
-        }
+        }
+
+    async def get_start_time(self, job_id: str) -> Optional[datetime]:
+        """Get the start time for a training job"""
+        try:
+            log_entry = await self.get_by_job_id(job_id)
+            if log_entry and log_entry.start_time:
+                return log_entry.start_time
+            return None
+        except Exception as e:
+            logger.error("Failed to get start time",
+                         job_id=job_id,
+                         error=str(e))
+            return None
```

progress_tracker.py

```diff
@@ -10,6 +10,11 @@ from datetime import datetime, timezone
 
 from app.services.training_events import publish_product_training_completed
 from app.utils.time_estimation import calculate_estimated_completion_time
+from app.core.training_constants import (
+    PROGRESS_TRAINING_RANGE_START,
+    PROGRESS_TRAINING_RANGE_END,
+    PROGRESS_TRAINING_RANGE_WIDTH
+)
 
 logger = structlog.get_logger()
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
         self.start_time = datetime.now(timezone.utc)
 
         # Calculate progress increment per product
-        # 60% of total progress (from 20% to 80%) divided by number of products
-        self.progress_per_product = 60 / total_products if total_products > 0 else 0
+        # Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
+        self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0
 
         logger.info("ParallelProductProgressTracker initialized",
                     job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
             estimated_completion_time=estimated_completion_time
         )
 
-        # Calculate overall progress (20% base + progress from completed products)
+        # Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
         # This calculation is done on the frontend/consumer side based on the event data
-        overall_progress = 20 + int((current_progress / self.total_products) * 60)
+        overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
 
         logger.info("Product training completed",
                     job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
         return {
             "products_completed": self.products_completed,
             "total_products": self.total_products,
-            "progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
+            "progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
         }
```

training_events.py

```diff
@@ -135,6 +135,61 @@ async def publish_data_analysis(
     return success
 
+async def publish_training_progress(
+    job_id: str,
+    tenant_id: str,
+    progress: int,
+    current_step: str,
+    step_details: Optional[str] = None,
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
+) -> bool:
+    """
+    Generic Training Progress Event (for any progress percentage)
+
+    Args:
+        job_id: Training job identifier
+        tenant_id: Tenant identifier
+        progress: Progress percentage (0-100)
+        current_step: Current step name
+        step_details: Details about the current step
+        estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
+    """
+    event_data = {
+        "service_name": "training-service",
+        "event_type": "training.progress",
+        "timestamp": datetime.now().isoformat(),
+        "data": {
+            "job_id": job_id,
+            "tenant_id": tenant_id,
+            "progress": progress,
+            "current_step": current_step,
+            "step_details": step_details or current_step,
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
+        }
+    }
+
+    success = await training_publisher.publish_event(
+        exchange_name="training.events",
+        routing_key="training.progress",
+        event_data=event_data
+    )
+
+    if success:
+        logger.info("Published training progress event",
+                    job_id=job_id,
+                    progress=progress,
+                    current_step=current_step)
+    else:
+        logger.error("Failed to publish training progress event",
+                     job_id=job_id,
+                     progress=progress)
+
+    return success
+
 async def publish_product_training_completed(
     job_id: str,
     tenant_id: str,
```

training_service.py

```diff
@@ -16,6 +16,15 @@ import pandas as pd
 from app.ml.trainer import EnhancedBakeryMLTrainer
 from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
 from app.services.training_orchestrator import TrainingDataOrchestrator
+from app.core.training_constants import (
+    PROGRESS_DATA_VALIDATION,
+    PROGRESS_DATA_PREPARATION_COMPLETE,
+    PROGRESS_ML_TRAINING_START,
+    PROGRESS_TRAINING_COMPLETE,
+    PROGRESS_STORING_MODELS,
+    PROGRESS_STORING_METRICS,
+    MAX_ESTIMATED_TIME_REMAINING_SECONDS
+)
 
 # Import repositories
 from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
             # Step 1: Prepare training dataset (includes sales data validation)
             logger.info("Step 1: Preparing and aligning training data (with validation)")
             await self.training_log_repo.update_log_progress(
-                job_id, 10, "data_validation", "running"
+                job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
             )
 
             # Orchestrator now handles sales data validation to eliminate duplicate fetching
@@ -204,13 +213,13 @@ class EnhancedTrainingService:
                 tenant_id=tenant_id, job_id=job_id)
 
             await self.training_log_repo.update_log_progress(
-                job_id, 30, "data_preparation_complete", "running"
+                job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
             )
 
             # Step 2: Execute ML training pipeline
             logger.info("Step 2: Starting ML training pipeline")
             await self.training_log_repo.update_log_progress(
-                job_id, 40, "ml_training", "running"
+                job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
             )
 
             training_results = await self.trainer.train_tenant_models(
@@ -220,9 +229,19 @@ class EnhancedTrainingService:
             )
 
             await self.training_log_repo.update_log_progress(
-                job_id, 85, "training_complete", "running"
+                job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
             )
 
+            # Publish progress event (85%)
+            from app.services.training_events import publish_training_progress
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_TRAINING_COMPLETE,
+                current_step="Training Complete",
+                step_details="All products trained successfully"
+            )
+
             # Step 3: Store model records using repository
             logger.info("Step 3: Storing model records")
             logger.debug("Training results structure",
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
             )
 
             await self.training_log_repo.update_log_progress(
-                job_id, 92, "storing_models", "running"
+                job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
             )
 
+            # Publish progress event (92%)
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_STORING_MODELS,
+                current_step="Storing Models",
+                step_details=f"Saved {len(stored_models)} trained models to database"
+            )
+
             # Step 4: Create performance metrics
             await self.training_log_repo.update_log_progress(
-                job_id, 94, "storing_performance_metrics", "running"
+                job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
             )
 
+            # Publish progress event (94%)
+            await publish_training_progress(
+                job_id=job_id,
+                tenant_id=tenant_id,
+                progress=PROGRESS_STORING_METRICS,
+                current_step="Storing Performance Metrics",
+                step_details="Saving model performance metrics"
+            )
+
             await self._create_performance_metrics(
@@ -316,8 +353,9 @@ class EnhancedTrainingService:
         except Exception as e:
             logger.error("Enhanced training job failed",
                          job_id=job_id,
-                         error=str(e))
+                         error=str(e),
+                         exc_info=True)
 
             # Mark as failed in database
             await self.training_log_repo.complete_training_log(
                 job_id, error_message=str(e)
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
             if log.status == "running" and log.progress > 0 and log.start_time:
                 from datetime import datetime, timezone
                 elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
-                if elapsed_time > 0:
+                if elapsed_time > 0 and log.progress > 0:  # Double-check progress is positive
                     # Calculate estimated total time based on progress
-                    estimated_total_time = (elapsed_time / log.progress) * 100
+                    # Use max(log.progress, 1) as additional safety against division by zero
+                    estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
                     estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
-                    # Cap at reasonable maximum (e.g., 30 minutes)
-                    estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
+                    # Cap at reasonable maximum (30 minutes)
+                    estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
 
             # Extract products info from results if available
             products_total = 0
```