79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
"""
Training Progress Tracker

Manages progress calculation for parallel product training, which covers the
20-80% range of overall training progress.
"""
|
|
|
|
import asyncio
|
|
import structlog
|
|
from typing import Optional
|
|
|
|
from app.services.training_events import publish_product_training_completed
|
|
|
|
# Structured logger used by ParallelProductProgressTracker below.
logger = structlog.get_logger()
|
|
|
|
|
|
class ParallelProductProgressTracker:
    """
    Tracks parallel product training progress and emits events.

    For N products training in parallel:
    - Each product completion contributes 60/N% to overall progress
    - Progress range: 20% (after data analysis) to 80% (before completion);
      reported progress is clamped to this band
    - Safe for concurrent product trainings running as coroutines on a single
      event loop. The counter is guarded by an ``asyncio.Lock``, which is not
      thread-safe, so an instance must not be shared across threads.
    """

    # Overall-progress band (in percent) covered by the parallel training phase.
    PROGRESS_START = 20
    PROGRESS_END = 80

    def __init__(self, job_id: str, tenant_id: str, total_products: int):
        """
        Initialise the tracker for one training job.

        Args:
            job_id: Training job identifier, included in events and logs.
            tenant_id: Tenant identifier, included in published events.
            total_products: Number of products trained in parallel. Values
                <= 0 give a per-product increment of 0.
        """
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.total_products = total_products
        self.products_completed = 0
        # Serialises updates to products_completed across coroutines.
        self._lock = asyncio.Lock()

        # Progress increment per product: the 60-point band (20% -> 80%)
        # split evenly across the products.
        span = self.PROGRESS_END - self.PROGRESS_START
        self.progress_per_product = span / total_products if total_products > 0 else 0

        logger.info("ParallelProductProgressTracker initialized",
                    job_id=job_id,
                    total_products=total_products,
                    progress_per_product=f"{self.progress_per_product:.2f}%")

    def _calculate_progress(self, completed: int) -> int:
        """
        Map a completed-product count to an overall progress percentage.

        Args:
            completed: Number of products completed so far.

        Returns:
            An integer in [PROGRESS_START, PROGRESS_END].
        """
        if self.total_products <= 0:
            # Nothing to train: no per-product contribution (consistent with
            # progress_per_product == 0), and no division by zero.
            return self.PROGRESS_START
        span = self.PROGRESS_END - self.PROGRESS_START
        # Clamp so duplicate/extra completions never exceed the band end.
        done = min(max(completed, 0), self.total_products)
        # Exact integer floor division avoids float-ratio truncation artefacts.
        return self.PROGRESS_START + (done * span) // self.total_products

    async def mark_product_completed(self, product_name: str) -> int:
        """
        Mark a product as completed and publish event.

        Args:
            product_name: Name of the product whose training finished.

        Returns:
            The current overall progress percentage, clamped to
            [PROGRESS_START, PROGRESS_END].
        """
        async with self._lock:
            self.products_completed += 1
            # Snapshot this call's count so the event, the log line and the
            # return value agree even if other products complete concurrently.
            current_progress = self.products_completed

        # Publish product completion event
        await publish_product_training_completed(
            job_id=self.job_id,
            tenant_id=self.tenant_id,
            product_name=product_name,
            products_completed=current_progress,
            total_products=self.total_products
        )

        # Overall progress = 20% base + share of completed products.
        # The frontend/consumer derives this from the event data itself; here
        # it is only logged and returned.
        overall_progress = self._calculate_progress(current_progress)

        logger.info("Product training completed",
                    job_id=self.job_id,
                    product_name=product_name,
                    products_completed=current_progress,
                    total_products=self.total_products,
                    overall_progress=overall_progress)

        return overall_progress

    def get_progress(self) -> dict:
        """
        Get current progress summary.

        Returns:
            A dict with ``products_completed``, ``total_products`` and
            ``progress_percentage`` (clamped to [PROGRESS_START, PROGRESS_END]).
        """
        return {
            "products_completed": self.products_completed,
            "total_products": self.total_products,
            "progress_percentage": self._calculate_progress(self.products_completed),
        }