Files
bakery-ia/services/training/app/services/progress_tracker.py
Claude 5a84be83d6 Fix multiple critical bugs in onboarding training step
This commit addresses all identified bugs and issues in the training code path:

## Critical Fixes:
- Add get_start_time() method to TrainingLogRepository and fix non-existent method call
- Remove duplicate training.started event from API endpoint (trainer publishes the accurate one)
- Add missing progress events for 80-100% range (85%, 92%, 94%) to eliminate progress "dead zone"

## High Priority Fixes:
- Fix division by zero risk in time estimation with double-check and max() safety
- Remove unreachable exception handler in training_operations.py
- Simplify WebSocket token refresh logic to only reconnect on actual user session changes

## Medium Priority Fixes:
- Fix auto-start training effect with useRef to prevent duplicate starts
- Add HTTP polling debounce delay (5s) to prevent race conditions with WebSocket
- Extract all magic numbers to centralized constants files:
  - Backend: services/training/app/core/training_constants.py
  - Frontend: frontend/src/constants/training.ts
- Standardize error logging with exc_info=True on critical errors

## Code Quality Improvements:
- All progress percentages now use named constants
- All timeouts and intervals now use named constants
- Improved code maintainability and readability
- Better separation of concerns

## Files Changed:
- Backend: training_service.py, trainer.py, training_events.py, progress_tracker.py
- Backend: training_operations.py, training_log_repository.py, training_constants.py (new)
- Frontend: training.ts (hooks), MLTrainingStep.tsx, training.ts (constants, new)

All training progress events now properly flow from 0% to 100% with no gaps.
2025-11-05 13:02:39 +00:00

109 lines
4.6 KiB
Python

"""
Training Progress Tracker
Manages progress calculation for parallel product training (20-80% range)
"""
import asyncio
import structlog
from typing import Optional
from datetime import datetime, timezone
from app.services.training_events import publish_product_training_completed
from app.utils.time_estimation import calculate_estimated_completion_time
from app.core.training_constants import (
PROGRESS_TRAINING_RANGE_START,
PROGRESS_TRAINING_RANGE_END,
PROGRESS_TRAINING_RANGE_WIDTH
)
# Module-level structured logger used by the tracker below.
logger = structlog.get_logger()


class ParallelProductProgressTracker:
    """
    Track parallel product-training progress and publish per-product events.

    For N products trained in parallel, each completed product advances the
    overall progress by PROGRESS_TRAINING_RANGE_WIDTH / N percentage points,
    starting from PROGRESS_TRAINING_RANGE_START. This window is documented for
    this service as 20% (after data analysis) to 80% (before completion);
    confirm the actual values in app.core.training_constants.

    Concurrency: the completion counter is guarded by an ``asyncio.Lock``. That
    makes the tracker safe for concurrent coroutines on a single event loop.
    It is NOT a thread lock, so instances must not be shared across OS threads.

    Time estimates extrapolate the average time per completed product,
    measured from construction, over the products that are still remaining.
    """

    def __init__(self, job_id: str, tenant_id: str, total_products: int):
        """
        Args:
            job_id: Training job identifier, forwarded with every event.
            tenant_id: Tenant identifier, forwarded with every event.
            total_products: Number of products expected to complete. A value
                <= 0 is tolerated: progress then stays at the start of the
                training range.
        """
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.total_products = total_products
        self.products_completed = 0
        self._lock = asyncio.Lock()
        # Baseline for the elapsed-time based remaining-time estimate.
        self.start_time = datetime.now(timezone.utc)

        # Share of the training range contributed by each finished product.
        # Zero when there is nothing to train, which avoids a division by zero.
        self.progress_per_product = (
            PROGRESS_TRAINING_RANGE_WIDTH / total_products if total_products > 0 else 0
        )

        logger.info("ParallelProductProgressTracker initialized",
                    job_id=job_id,
                    total_products=total_products,
                    progress_per_product=f"{self.progress_per_product:.2f}%")

    async def mark_product_completed(self, product_name: str) -> int:
        """
        Record one finished product, publish its completion event, and return
        the overall progress percentage.

        Args:
            product_name: Name of the product whose training just finished.

        Returns:
            The overall progress. It starts at PROGRESS_TRAINING_RANGE_START and
            never exceeds START + PROGRESS_TRAINING_RANGE_WIDTH, even if more
            completions are reported than ``total_products``.
        """
        async with self._lock:
            self.products_completed += 1
            completed = self.products_completed

            remaining_seconds, completion_time = self._estimate_remaining_time(completed)

            # The publish is awaited inside the lock, so completion events are
            # published in the same order as the counter increments. The cost
            # is that other completions wait while this one is published.
            await publish_product_training_completed(
                job_id=self.job_id,
                tenant_id=self.tenant_id,
                product_name=product_name,
                products_completed=completed,
                total_products=self.total_products,
                estimated_time_remaining_seconds=remaining_seconds,
                estimated_completion_time=completion_time
            )

            # The event carries only counts; consumers derive the percentage
            # themselves. It is computed here for logging and the return value.
            overall_progress = self._overall_progress(completed)

            logger.info("Product training completed",
                        job_id=self.job_id,
                        product_name=product_name,
                        products_completed=completed,
                        total_products=self.total_products,
                        overall_progress=overall_progress,
                        estimated_time_remaining_seconds=remaining_seconds)

            return overall_progress

    def get_progress(self) -> dict:
        """Return a snapshot of completed/total products and the overall progress percentage."""
        return {
            "products_completed": self.products_completed,
            "total_products": self.total_products,
            "progress_percentage": self._overall_progress(self.products_completed)
        }

    def _overall_progress(self, completed: int):
        """Map a completed-product count onto the overall progress percentage."""
        if self.total_products <= 0:
            # Nothing to train: stay at the range start instead of dividing by
            # zero. This is consistent with progress_per_product == 0 in that case.
            return PROGRESS_TRAINING_RANGE_START
        # Clamp so over-reported completions cannot push past the range end.
        fraction = min(completed / self.total_products, 1.0)
        return PROGRESS_TRAINING_RANGE_START + int(fraction * PROGRESS_TRAINING_RANGE_WIDTH)

    def _estimate_remaining_time(self, completed: int) -> tuple:
        """
        Extrapolate the remaining time from the average pace so far.

        Returns:
            A ``(remaining_seconds, completion_time_iso)`` pair. It is
            ``(None, None)`` when no product has finished yet or none remain.
        """
        remaining_products = self.total_products - completed
        if completed <= 0 or remaining_products <= 0:
            return None, None

        elapsed_seconds = (datetime.now(timezone.utc) - self.start_time).total_seconds()
        avg_seconds_per_product = elapsed_seconds / completed
        remaining_seconds = int(avg_seconds_per_product * remaining_products)

        # calculate_estimated_completion_time takes a duration in minutes.
        completion_datetime = calculate_estimated_completion_time(remaining_seconds / 60)
        return remaining_seconds, completion_datetime.isoformat()