Improve the sales import
@@ -165,14 +165,6 @@ async def start_training_job(
     if metrics:
         metrics.increment_counter("enhanced_training_jobs_created_total")

-    # Publish training.started event immediately so WebSocket clients
-    # have initial state when they connect
-    await publish_training_started(
-        job_id=job_id,
-        tenant_id=tenant_id,
-        total_products=0  # Will be updated when actual training starts
-    )
-
     # Calculate intelligent time estimate
     # We don't know exact product count yet, so use historical average or estimate
     try:
@@ -192,6 +184,19 @@ async def start_training_job(
                      error=str(est_error))
         estimated_duration_minutes = 15  # Default fallback

+    # Calculate estimated completion time
+    estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
+
+    # Publish training.started event immediately so WebSocket clients
+    # have initial state when they connect
+    await publish_training_started(
+        job_id=job_id,
+        tenant_id=tenant_id,
+        total_products=0,  # Will be updated when actual training starts
+        estimated_duration_minutes=estimated_duration_minutes,
+        estimated_completion_time=estimated_completion_time.isoformat()
+    )
+
     # Add enhanced background task
     background_tasks.add_task(
         execute_training_job_background,
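Note: `calculate_estimated_completion_time` is imported from `app.utils.time_estimation` but its body is not part of this diff. A minimal sketch of the behavior the call sites assume (the duration is added to the current UTC time, and the result supports `.isoformat()`):

```python
from datetime import datetime, timedelta, timezone

def calculate_estimated_completion_time(estimated_duration_minutes: float) -> datetime:
    # Assumed behavior: project a completion timestamp by adding the
    # estimated duration to now (UTC). Call sites invoke .isoformat()
    # on the result, so it must return a datetime.
    return datetime.now(timezone.utc) + timedelta(minutes=estimated_duration_minutes)
```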
@@ -362,15 +367,8 @@ async def execute_training_job_background(
             requested_end=requested_end
         )

-        # Update final status using repository pattern
-        await enhanced_training_service._update_job_status_repository(
-            job_id=job_id,
-            status="completed",
-            progress=100,
-            current_step="Enhanced training completed successfully",
-            results=result,
-            tenant_id=tenant_id
-        )
+        # Note: Final status is already updated by start_training_job() via complete_training_log()
+        # No need for redundant update here - it was causing duplicate log entries

         # Completion event is published by the training service

@@ -138,14 +138,14 @@ class DataClient:
                 self._fetch_sales_data_internal,
                 tenant_id, start_date, end_date, product_id, fetch_all
             )
-        except CircuitBreakerError as e:
-            logger.error(f"Sales service circuit breaker open: {e}")
-            raise RuntimeError(f"Sales service unavailable: {str(e)}")
+        except CircuitBreakerError as exc:
+            logger.error("Sales service circuit breaker open", error_message=str(exc))
+            raise RuntimeError(f"Sales service unavailable: {str(exc)}")
         except ValueError:
             raise
-        except Exception as e:
-            logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
-            raise RuntimeError(f"Failed to fetch sales data: {str(e)}")
+        except Exception as exc:
+            logger.error("Error fetching sales data", tenant_id=tenant_id, error_message=str(exc))
+            raise RuntimeError(f"Failed to fetch sales data: {str(exc)}")

     async def fetch_weather_data(
         self,
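The recurring change in this file swaps f-string interpolation for structlog key-value fields, so the exception lands in a queryable field instead of being baked into the message text. A small self-contained illustration of the difference (the tenant id is invented):

```python
import structlog

logger = structlog.get_logger()

try:
    raise TimeoutError("sales-service timed out")
except Exception as exc:
    # Before: logger.error(f"Error fetching sales data: {exc}") embeds the
    # error in the message string, so every log line is unique. After: the
    # message stays constant and the details become structured fields that
    # log pipelines can group and filter on.
    logger.error("Error fetching sales data",
                 tenant_id="tenant-123", error_message=str(exc))
```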
@@ -176,8 +176,8 @@ class DataClient:
                 logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
                 return []

-        except Exception as e:
-            logger.warning(f"Error fetching weather data, will use synthetic data: {e}", tenant_id=tenant_id)
+        except Exception as exc:
+            logger.warning("Error fetching weather data, will use synthetic data", tenant_id=tenant_id, error_message=str(exc))
             return []

     async def fetch_traffic_data_unified(
@@ -254,9 +254,9 @@ class DataClient:
                 logger.warning("No fresh traffic data available", tenant_id=tenant_id)
                 return []

-        except Exception as e:
-            logger.error(f"Error in unified traffic data fetch: {e}",
-                         tenant_id=tenant_id, cache_key=cache_key)
+        except Exception as exc:
+            logger.error("Error in unified traffic data fetch",
+                         tenant_id=tenant_id, cache_key=cache_key, error_message=str(exc))
             return []

     # Legacy methods for backward compatibility - now delegate to unified method
@@ -405,9 +405,9 @@ class DataClient:

             return result

-        except Exception as e:
-            logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
-            raise ValueError(f"Data validation failed: {str(e)}")
+        except Exception as exc:
+            logger.error("Error validating data", tenant_id=tenant_id, error_message=str(exc))
+            raise ValueError(f"Data validation failed: {str(exc)}")

 # Global instance - same as before, but much simpler implementation
 data_client = DataClient()
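The context line above mentions legacy methods that now delegate to the unified fetch. Their bodies sit outside this diff; a plausible shape for one such wrapper (the method name and signature here are assumptions, only `fetch_traffic_data_unified` appears in the diff):

```python
async def fetch_traffic_data(self, tenant_id: str):
    # Hypothetical legacy wrapper: keep the old entry point alive for
    # existing callers but route everything through the unified fetch.
    return await self.fetch_traffic_data_unified(tenant_id)
```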
@@ -6,8 +6,10 @@ Manages progress calculation for parallel product training (20-80% range)
 import asyncio
 import structlog
 from typing import Optional
+from datetime import datetime, timezone

 from app.services.training_events import publish_product_training_completed
+from app.utils.time_estimation import calculate_estimated_completion_time

 logger = structlog.get_logger()

@@ -20,6 +22,7 @@ class ParallelProductProgressTracker:
     - Each product completion contributes 60/N% to overall progress
     - Progress range: 20% (after data analysis) to 80% (before completion)
     - Thread-safe for concurrent product trainings
+    - Calculates time estimates based on elapsed time and progress
     """

     def __init__(self, job_id: str, tenant_id: str, total_products: int):
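For orientation, how a caller would exercise the tracker given the constructor and method in this diff (identifiers invented; the call also publishes a completion event, so it needs the event bus available):

```python
# With total_products=10, each completion adds 60/10 = 6% on top of the 20% base.
tracker = ParallelProductProgressTracker(job_id="job-123", tenant_id="tenant-1",
                                         total_products=10)

# Inside an async context:
progress = await tracker.mark_product_completed("espresso")
assert progress == 26  # 20% base + 1 completed product * 6%, assuming integer rounding
```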
@@ -28,6 +31,7 @@ class ParallelProductProgressTracker:
         self.total_products = total_products
         self.products_completed = 0
         self._lock = asyncio.Lock()
+        self.start_time = datetime.now(timezone.utc)

         # Calculate progress increment per product
         # 60% of total progress (from 20% to 80%) divided by number of products
@@ -40,20 +44,40 @@ class ParallelProductProgressTracker:

     async def mark_product_completed(self, product_name: str) -> int:
         """
-        Mark a product as completed and publish event.
+        Mark a product as completed and publish event with time estimates.
         Returns the current overall progress percentage.
         """
         async with self._lock:
             self.products_completed += 1
             current_progress = self.products_completed

-            # Publish product completion event
+            # Calculate time estimates based on elapsed time and progress
+            elapsed_seconds = (datetime.now(timezone.utc) - self.start_time).total_seconds()
+            products_remaining = self.total_products - current_progress
+
+            # Calculate estimated time remaining
+            # Avg time per product * remaining products
+            estimated_time_remaining_seconds = None
+            estimated_completion_time = None
+
+            if current_progress > 0 and products_remaining > 0:
+                avg_time_per_product = elapsed_seconds / current_progress
+                estimated_time_remaining_seconds = int(avg_time_per_product * products_remaining)
+
+                # Calculate estimated completion time
+                estimated_duration_minutes = estimated_time_remaining_seconds / 60
+                completion_datetime = calculate_estimated_completion_time(estimated_duration_minutes)
+                estimated_completion_time = completion_datetime.isoformat()
+
+            # Publish product completion event with time estimates
             await publish_product_training_completed(
                 job_id=self.job_id,
                 tenant_id=self.tenant_id,
                 product_name=product_name,
                 products_completed=current_progress,
-                total_products=self.total_products
+                total_products=self.total_products,
+                estimated_time_remaining_seconds=estimated_time_remaining_seconds,
+                estimated_completion_time=estimated_completion_time
             )

             # Calculate overall progress (20% base + progress from completed products)
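The estimate is a plain running average over completed products; a worked example with invented numbers:

```python
# 4 of 10 products finished after 2 minutes of training.
elapsed_seconds = 120.0
current_progress = 4
total_products = 10

avg_time_per_product = elapsed_seconds / current_progress               # 30.0 s
products_remaining = total_products - current_progress                  # 6
estimated_time_remaining_seconds = int(avg_time_per_product * products_remaining)
print(estimated_time_remaining_seconds)  # 180 -> about 3 more minutes
```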
@@ -65,7 +89,8 @@ class ParallelProductProgressTracker:
                         product_name=product_name,
                         products_completed=current_progress,
                         total_products=self.total_products,
-                        overall_progress=overall_progress)
+                        overall_progress=overall_progress,
+                        estimated_time_remaining_seconds=estimated_time_remaining_seconds)

             return overall_progress

@@ -91,7 +91,8 @@ async def publish_data_analysis(
     job_id: str,
     tenant_id: str,
     analysis_details: Optional[str] = None,
-    estimated_time_remaining_seconds: Optional[int] = None
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
 ) -> bool:
     """
     Event 2: Data Analysis (20% progress)
@@ -101,6 +102,7 @@ async def publish_data_analysis(
         tenant_id: Tenant identifier
         analysis_details: Details about the analysis
         estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
     """
     event_data = {
         "service_name": "training-service",
@@ -112,7 +114,8 @@ async def publish_data_analysis(
             "progress": 20,
             "current_step": "Data Analysis",
             "step_details": analysis_details or "Analyzing sales, weather, and traffic data",
-            "estimated_time_remaining_seconds": estimated_time_remaining_seconds
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
         }
     }

@@ -138,7 +141,8 @@ async def publish_product_training_completed(
     product_name: str,
     products_completed: int,
     total_products: int,
-    estimated_time_remaining_seconds: Optional[int] = None
+    estimated_time_remaining_seconds: Optional[int] = None,
+    estimated_completion_time: Optional[str] = None
 ) -> bool:
     """
     Event 3: Product Training Completed (contributes to 20-80% progress)
@@ -154,6 +158,7 @@ async def publish_product_training_completed(
         products_completed: Number of products completed so far
         total_products: Total number of products
        estimated_time_remaining_seconds: Estimated time remaining in seconds
+        estimated_completion_time: ISO timestamp of estimated completion
     """
     event_data = {
         "service_name": "training-service",
@@ -167,7 +172,8 @@ async def publish_product_training_completed(
             "total_products": total_products,
             "current_step": "Model Training",
             "step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
-            "estimated_time_remaining_seconds": estimated_time_remaining_seconds
+            "estimated_time_remaining_seconds": estimated_time_remaining_seconds,
+            "estimated_completion_time": estimated_completion_time
         }
     }

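With the widened signature, a call supplying the new fields would look like this (all values invented):

```python
await publish_product_training_completed(
    job_id="job-123",
    tenant_id="tenant-1",
    product_name="espresso",
    products_completed=4,
    total_products=10,
    estimated_time_remaining_seconds=180,
    estimated_completion_time="2025-01-01T12:03:00+00:00",
)
```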
@@ -238,11 +238,19 @@ class EnhancedTrainingService:
             )

             # Step 4: Create performance metrics
+            await self.training_log_repo.update_log_progress(
+                job_id, 94, "storing_performance_metrics", "running"
+            )
+
             await self._create_performance_metrics(
                 tenant_id, stored_models, training_results
             )

+            # Step 4.5: Save training performance metrics for future estimations
+            await self._save_training_performance_metrics(
+                tenant_id, job_id, training_results, training_log
+            )

             # Step 5: Complete training log
             final_result = {
                 "job_id": job_id,
@@ -426,7 +434,7 @@ class EnhancedTrainingService:
                 model_result = training_results.get("models_trained", {}).get(str(model.inventory_product_id))
                 if model_result and model_result.get("metrics"):
                     metrics = model_result["metrics"]
-
+
                     metric_data = {
                         "model_id": str(model.id),
                         "tenant_id": tenant_id,
@@ -439,13 +447,84 @@ class EnhancedTrainingService:
                         "accuracy_percentage": metrics.get("accuracy_percentage", 100 - metrics.get("mape", 0)),
                         "evaluation_samples": model.training_samples
                     }
-
+
                     await self.performance_repo.create_performance_metric(metric_data)
-
+
         except Exception as e:
             logger.error("Failed to create performance metrics",
                          tenant_id=tenant_id,
                          error=str(e))

+    async def _save_training_performance_metrics(
+        self,
+        tenant_id: str,
+        job_id: str,
+        training_results: Dict[str, Any],
+        training_log
+    ):
+        """
+        Save aggregated training performance metrics for time estimation.
+        This data is used to predict future training durations.
+        """
+        try:
+            from app.models.training import TrainingPerformanceMetrics
+
+            # Extract timing and success data
+            models_trained = training_results.get("models_trained", {})
+            total_products = len(models_trained)
+            successful_products = sum(1 for m in models_trained.values() if m.get("status") == "completed")
+            failed_products = total_products - successful_products
+
+            # Calculate total duration
+            if training_log.start_time and training_log.end_time:
+                total_duration_seconds = (training_log.end_time - training_log.start_time).total_seconds()
+            else:
+                # Fallback to elapsed time
+                total_duration_seconds = training_results.get("total_training_time", 0)
+
+            # Calculate average time per product
+            if successful_products > 0:
+                avg_time_per_product = total_duration_seconds / successful_products
+            else:
+                avg_time_per_product = 0
+
+            # Extract timing breakdown if available
+            data_analysis_time = training_results.get("data_analysis_time_seconds")
+            training_time = training_results.get("training_time_seconds")
+            finalization_time = training_results.get("finalization_time_seconds")
+
+            # Create performance metrics record
+            metric_data = {
+                "tenant_id": tenant_id,
+                "job_id": job_id,
+                "total_products": total_products,
+                "successful_products": successful_products,
+                "failed_products": failed_products,
+                "total_duration_seconds": total_duration_seconds,
+                "avg_time_per_product": avg_time_per_product,
+                "data_analysis_time_seconds": data_analysis_time,
+                "training_time_seconds": training_time,
+                "finalization_time_seconds": finalization_time,
+                "completed_at": datetime.now(timezone.utc)
+            }
+
+            # Use repository to create record
+            performance_metrics = TrainingPerformanceMetrics(**metric_data)
+            self.session.add(performance_metrics)
+            await self.session.commit()
+
+            logger.info("Saved training performance metrics for future estimations",
+                        tenant_id=tenant_id,
+                        job_id=job_id,
+                        avg_time_per_product=avg_time_per_product,
+                        total_products=total_products,
+                        successful_products=successful_products)
+
+        except Exception as e:
+            logger.error("Failed to save training performance metrics",
+                         tenant_id=tenant_id,
+                         job_id=job_id,
+                         error=str(e))
+
     async def get_training_status(self, job_id: str) -> Dict[str, Any]:
         """Get training job status using repository"""
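These saved rows presumably seed the "historical average" that `start_training_job` mentions when no product count is known yet. The read side is not part of this diff; a hypothetical sketch of what it could look like, reusing the 15-minute default from the endpoint:

```python
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.training import TrainingPerformanceMetrics

async def estimate_duration_minutes(session: AsyncSession, tenant_id: str) -> float:
    # Hypothetical helper: average past total durations for this tenant,
    # falling back to the same 15-minute default used in start_training_job.
    result = await session.execute(
        select(func.avg(TrainingPerformanceMetrics.total_duration_seconds))
        .where(TrainingPerformanceMetrics.tenant_id == tenant_id)
    )
    avg_seconds = result.scalar()
    return avg_seconds / 60.0 if avg_seconds else 15.0
```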