Fix multiple critical bugs in onboarding training step

This commit addresses all identified bugs and issues in the training code path:

## Critical Fixes:
- Add get_start_time() method to TrainingLogRepository and fix non-existent method call
- Remove duplicate training.started event from API endpoint (trainer publishes the accurate one)
- Add missing progress events for 80-100% range (85%, 92%, 94%) to eliminate progress "dead zone"

## High Priority Fixes:
- Fix division by zero risk in time estimation with double-check and max() safety
- Remove unreachable exception handler in training_operations.py
- Simplify WebSocket token refresh logic to only reconnect on actual user session changes

## Medium Priority Fixes:
- Fix auto-start training effect with useRef to prevent duplicate starts
- Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket
- Extract all magic numbers to centralized constants files:
  - Backend: services/training/app/core/training_constants.py
  - Frontend: frontend/src/constants/training.ts
- Standardize error logging with exc_info=True on critical errors

## Code Quality Improvements:
- All progress percentages now use named constants
- All timeouts and intervals now use named constants
- Improved code maintainability and readability
- Better separation of concerns

## Files Changed:
- Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py
- Backend: training_operations.py, training_log_repository.py, training_constants.py (new)
- Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new)

All training progress events now properly flow from 0% to 100% with no gaps.
This commit is contained in:
Claude
2025-11-05 13:02:39 +00:00
parent e3ea92640b
commit 5a84be83d6
10 changed files with 291 additions and 106 deletions

View File

@@ -10,6 +10,11 @@ from datetime import datetime, timezone
from app.services.training_events import publish_product_training_completed
from app.utils.time_estimation import calculate_estimated_completion_time
from app.core.training_constants import (
PROGRESS_TRAINING_RANGE_START,
PROGRESS_TRAINING_RANGE_END,
PROGRESS_TRAINING_RANGE_WIDTH
)
logger = structlog.get_logger()
@@ -34,8 +39,8 @@ class ParallelProductProgressTracker:
self.start_time = datetime.now(timezone.utc)
# Calculate progress increment per product
# 60% of total progress (from 20% to 80%) divided by number of products
self.progress_per_product = 60 / total_products if total_products > 0 else 0
# Training range (from PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END) divided by number of products
self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0
logger.info("ParallelProductProgressTracker initialized",
job_id=job_id,
@@ -80,9 +85,9 @@ class ParallelProductProgressTracker:
estimated_completion_time=estimated_completion_time
)
# Calculate overall progress (20% base + progress from completed products)
# Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
# This calculation is done on the frontend/consumer side based on the event data
overall_progress = 20 + int((current_progress / self.total_products) * 60)
overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
logger.info("Product training completed",
job_id=self.job_id,
@@ -99,5 +104,5 @@ class ParallelProductProgressTracker:
return {
"products_completed": self.products_completed,
"total_products": self.total_products,
"progress_percentage": 20 + int((self.products_completed / self.total_products) * 60)
"progress_percentage": PROGRESS_TRAINING_RANGE_START + int((self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
}

View File

@@ -135,6 +135,61 @@ async def publish_data_analysis(
return success
async def publish_training_progress(
job_id: str,
tenant_id: str,
progress: int,
current_step: str,
step_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Generic Training Progress Event (for any progress percentage)
Args:
job_id: Training job identifier
tenant_id: Tenant identifier
progress: Progress percentage (0-100)
current_step: Current step name
step_details: Details about the current step
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
"event_type": "training.progress",
"timestamp": datetime.now().isoformat(),
"data": {
"job_id": job_id,
"tenant_id": tenant_id,
"progress": progress,
"current_step": current_step,
"step_details": step_details or current_step,
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}
success = await training_publisher.publish_event(
exchange_name="training.events",
routing_key="training.progress",
event_data=event_data
)
if success:
logger.info("Published training progress event",
job_id=job_id,
progress=progress,
current_step=current_step)
else:
logger.error("Failed to publish training progress event",
job_id=job_id,
progress=progress)
return success
async def publish_product_training_completed(
job_id: str,
tenant_id: str,

View File

@@ -16,6 +16,15 @@ import pandas as pd
from app.ml.trainer import EnhancedBakeryMLTrainer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.training_constants import (
PROGRESS_DATA_VALIDATION,
PROGRESS_DATA_PREPARATION_COMPLETE,
PROGRESS_ML_TRAINING_START,
PROGRESS_TRAINING_COMPLETE,
PROGRESS_STORING_MODELS,
PROGRESS_STORING_METRICS,
MAX_ESTIMATED_TIME_REMAINING_SECONDS
)
# Import repositories
from app.repositories import (
@@ -187,7 +196,7 @@ class EnhancedTrainingService:
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
job_id, 10, "data_validation", "running"
job_id, PROGRESS_DATA_VALIDATION, "data_validation", "running"
)
# Orchestrator now handles sales data validation to eliminate duplicate fetching
@@ -204,13 +213,13 @@ class EnhancedTrainingService:
tenant_id=tenant_id, job_id=job_id)
await self.training_log_repo.update_log_progress(
job_id, 30, "data_preparation_complete", "running"
job_id, PROGRESS_DATA_PREPARATION_COMPLETE, "data_preparation_complete", "running"
)
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
await self.training_log_repo.update_log_progress(
job_id, 40, "ml_training", "running"
job_id, PROGRESS_ML_TRAINING_START, "ml_training", "running"
)
training_results = await self.trainer.train_tenant_models(
@@ -220,9 +229,19 @@ class EnhancedTrainingService:
)
await self.training_log_repo.update_log_progress(
job_id, 85, "training_complete", "running"
job_id, PROGRESS_TRAINING_COMPLETE, "training_complete", "running"
)
# Publish progress event (85%)
from app.services.training_events import publish_training_progress
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_TRAINING_COMPLETE,
current_step="Training Complete",
step_details="All products trained successfully"
)
# Step 3: Store model records using repository
logger.info("Step 3: Storing model records")
logger.debug("Training results structure",
@@ -234,12 +253,30 @@ class EnhancedTrainingService:
)
await self.training_log_repo.update_log_progress(
job_id, 92, "storing_models", "running"
job_id, PROGRESS_STORING_MODELS, "storing_models", "running"
)
# Publish progress event (92%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_MODELS,
current_step="Storing Models",
step_details=f"Saved {len(stored_models)} trained models to database"
)
# Step 4: Create performance metrics
await self.training_log_repo.update_log_progress(
job_id, 94, "storing_performance_metrics", "running"
job_id, PROGRESS_STORING_METRICS, "storing_performance_metrics", "running"
)
# Publish progress event (94%)
await publish_training_progress(
job_id=job_id,
tenant_id=tenant_id,
progress=PROGRESS_STORING_METRICS,
current_step="Storing Performance Metrics",
step_details="Saving model performance metrics"
)
await self._create_performance_metrics(
@@ -316,8 +353,9 @@ class EnhancedTrainingService:
except Exception as e:
logger.error("Enhanced training job failed",
job_id=job_id,
error=str(e))
error=str(e),
exc_info=True)
# Mark as failed in database
await self.training_log_repo.complete_training_log(
job_id, error_message=str(e)
@@ -519,12 +557,13 @@ class EnhancedTrainingService:
if log.status == "running" and log.progress > 0 and log.start_time:
from datetime import datetime, timezone
elapsed_time = (datetime.now(timezone.utc) - log.start_time).total_seconds()
if elapsed_time > 0:
if elapsed_time > 0 and log.progress > 0: # Double-check progress is positive
# Calculate estimated total time based on progress
estimated_total_time = (elapsed_time / log.progress) * 100
# Use max(log.progress, 1) as additional safety against division by zero
estimated_total_time = (elapsed_time / max(log.progress, 1)) * 100
estimated_time_remaining_seconds = int(estimated_total_time - elapsed_time)
# Cap at reasonable maximum (e.g., 30 minutes)
estimated_time_remaining_seconds = min(estimated_time_remaining_seconds, 1800)
# Cap at reasonable maximum (30 minutes)
estimated_time_remaining_seconds = max(0, min(estimated_time_remaining_seconds, MAX_ESTIMATED_TIME_REMAINING_SECONDS))
# Extract products info from results if available
products_total = 0