Fix new Frontend 13
This commit is contained in:
@@ -20,12 +20,7 @@ from app.core.config import settings
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
publish_data_validation_completed,
|
||||
publish_job_step_completed
|
||||
)
|
||||
from app.services.messaging import TrainingStatusPublisher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,6 +54,8 @@ class BakeryMLTrainer:
|
||||
|
||||
logger.info(f"Starting ML training pipeline {job_id} for tenant {tenant_id}")
|
||||
|
||||
self.status_publisher = TrainingStatusPublisher(job_id, tenant_id)
|
||||
|
||||
try:
|
||||
# Convert sales data to DataFrame
|
||||
sales_df = pd.DataFrame(training_dataset.sales_data)
|
||||
@@ -72,12 +69,18 @@ class BakeryMLTrainer:
|
||||
products = sales_df['product_name'].unique().tolist()
|
||||
logger.info(f"Training models for {len(products)} products: {products}")
|
||||
|
||||
self.status_publisher.products_total = len(products)
|
||||
|
||||
# Process data for each product
|
||||
logger.info("Processing data for all products...")
|
||||
processed_data = await self._process_all_products(
|
||||
sales_df, weather_df, traffic_df, products
|
||||
)
|
||||
await publish_job_progress(job_id, tenant_id, 20, "feature_engineering", estimated_time_remaining_minutes=7)
|
||||
await self.status_publisher.progress_update(
|
||||
progress=20,
|
||||
step="feature_engineering",
|
||||
step_details="Processing features for all products"
|
||||
)
|
||||
|
||||
# Train models for each processed product
|
||||
logger.info("Training models for all products...")
|
||||
@@ -87,7 +90,11 @@ class BakeryMLTrainer:
|
||||
|
||||
# Calculate overall training summary
|
||||
summary = self._calculate_training_summary(training_results)
|
||||
await publish_job_progress(job_id, tenant_id, 90, "model_validation", estimated_time_remaining_minutes=1)
|
||||
await self.status_publisher.progress_update(
|
||||
progress=90,
|
||||
step="model_validation",
|
||||
step_details="Validating model performance"
|
||||
)
|
||||
|
||||
result = {
|
||||
"job_id": job_id,
|
||||
@@ -399,16 +406,11 @@ class BakeryMLTrainer:
|
||||
job_id: str) -> Dict[str, Any]:
|
||||
"""Train models for all processed products using Prophet manager"""
|
||||
training_results = {}
|
||||
|
||||
i = 0
|
||||
total_products = len(processed_data)
|
||||
base_progress = 45
|
||||
max_progress = 85 # or whatever your target end progress is
|
||||
products_total = 0
|
||||
i = 0
|
||||
max_progress = 85
|
||||
|
||||
start_time = time.time()
|
||||
processing_times = [] # Store individual processing times
|
||||
|
||||
for product_name, product_data in processed_data.items():
|
||||
product_start_time = time.time()
|
||||
try:
|
||||
@@ -424,7 +426,6 @@ class BakeryMLTrainer:
|
||||
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
|
||||
}
|
||||
logger.warning(f"Skipping {product_name}: insufficient data ({len(product_data)} < {settings.MIN_TRAINING_DATA_DAYS})")
|
||||
processing_times.append(time.time() - product_start_time)
|
||||
continue
|
||||
|
||||
# Train the model using Prophet manager
|
||||
@@ -444,6 +445,20 @@ class BakeryMLTrainer:
|
||||
|
||||
logger.info(f"Successfully trained model for {product_name}")
|
||||
|
||||
completed_products = i + 1
|
||||
progress = base_progress + int((completed_products / total_products) * (max_progress - base_progress))
|
||||
|
||||
if self.status_publisher:
|
||||
# Update products completed for accurate tracking
|
||||
self.status_publisher.products_completed = completed_products
|
||||
|
||||
await self.status_publisher.product_completed(
|
||||
progress=progress,
|
||||
step="model_training",
|
||||
current_product=product_name,
|
||||
step_details=f"Completed training for {product_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to train model for {product_name}: {str(e)}")
|
||||
training_results[product_name] = {
|
||||
@@ -452,29 +467,18 @@ class BakeryMLTrainer:
|
||||
'data_points': len(product_data) if product_data is not None else 0,
|
||||
'failed_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
completed_products = i + 1
|
||||
|
||||
if self.status_publisher:
|
||||
self.status_publisher.products_completed = completed_products
|
||||
await self.status_publisher.progress_update(
|
||||
progress=progress,
|
||||
step="model_training",
|
||||
current_product=product_name,
|
||||
step_details=f"Failed training for {product_name}: {str(e)}"
|
||||
)
|
||||
|
||||
# Record processing time for this product
|
||||
product_processing_time = time.time() - product_start_time
|
||||
processing_times.append(product_processing_time)
|
||||
|
||||
i += 1
|
||||
current_progress = base_progress + int((i / total_products) * (max_progress - base_progress))
|
||||
|
||||
# Calculate estimated time remaining
|
||||
estimated_time_remaining_minutes = self.calculate_estimated_time_remaining(
|
||||
processing_times, i, total_products
|
||||
)
|
||||
|
||||
await publish_job_progress(
|
||||
job_id,
|
||||
tenant_id,
|
||||
current_progress,
|
||||
"model_training",
|
||||
product_name,
|
||||
products_total,
|
||||
total_products,
|
||||
estimated_time_remaining_minutes=estimated_time_remaining_minutes
|
||||
)
|
||||
|
||||
return training_results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user