Files
bakery-ia/services/training/app/utils/time_estimation.py
2025-10-15 16:12:49 +02:00

333 lines
10 KiB
Python

"""
Training Time Estimation Utilities
Provides intelligent time estimation for training jobs based on:
- Product count
- Historical performance data
- Current progress and throughput
"""
import math
from datetime import datetime, timedelta, timezone
from typing import List, Optional

import structlog
logger = structlog.get_logger()
def calculate_initial_estimate(
    total_products: int,
    avg_training_time_per_product: float = 60.0,  # seconds, default 1 min/product
    data_analysis_overhead: float = 120.0,  # seconds, data loading & analysis
    finalization_overhead: float = 60.0,  # seconds, saving models & cleanup
    min_estimate_minutes: int = 5,
    max_estimate_minutes: int = 60
) -> int:
    """
    Calculate realistic initial time estimate for training job.

    Formula:
        total_time = data_analysis + (products * avg_time_per_product) + finalization

    Args:
        total_products: Number of products to train
        avg_training_time_per_product: Average time per product in seconds
        data_analysis_overhead: Time for data loading and analysis in seconds
        finalization_overhead: Time for saving models and cleanup in seconds
        min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
        max_estimate_minutes: Maximum estimate (prevents unrealistic high values)

    Returns:
        Estimated duration in minutes

    Examples:
        >>> calculate_initial_estimate(1)
        5  # 120 + 60 + 60 = 240s = 4min, raised to the 5min minimum
        >>> calculate_initial_estimate(5)
        8  # 120 + 300 + 60 = 480s = 8min
        >>> calculate_initial_estimate(10)
        13  # 120 + 600 + 60 = 780s = 13min
        >>> calculate_initial_estimate(20)
        23  # 120 + 1200 + 60 = 1380s = 23min
        >>> calculate_initial_estimate(100)
        60  # Capped at max (would be 103 min)
    """
    # Total estimated wall-clock time in seconds.
    estimated_seconds = (
        data_analysis_overhead +
        (total_products * avg_training_time_per_product) +
        finalization_overhead
    )
    # Convert to minutes, rounding UP so a partial minute is never dropped.
    # (The previous `int(x + 0.5)` rounded half-up, which could under-estimate
    # and contradicted the documented "round up" intent.)
    estimated_minutes = math.ceil(estimated_seconds / 60)
    # Clamp to the configured [min, max] window to avoid unrealistic values.
    estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))
    logger.info(
        "Calculated initial time estimate",
        total_products=total_products,
        estimated_seconds=estimated_seconds,
        estimated_minutes=estimated_minutes,
        avg_time_per_product=avg_training_time_per_product
    )
    return estimated_minutes
def calculate_estimated_completion_time(
    estimated_duration_minutes: int,
    start_time: Optional[datetime] = None
) -> datetime:
    """
    Compute the timestamp at which the job is expected to finish.

    Args:
        estimated_duration_minutes: Estimated duration in minutes
        start_time: Job start time (defaults to the current UTC time)

    Returns:
        Estimated completion datetime (timezone-aware UTC)
    """
    base = start_time if start_time is not None else datetime.now(timezone.utc)
    return base + timedelta(minutes=estimated_duration_minutes)
def calculate_remaining_time_smart(
    progress: int,
    elapsed_time: float,
    products_completed: int,
    total_products: int,
    recent_product_times: Optional[List[float]] = None,
    max_remaining_seconds: int = 1800  # 30 minutes safety cap
) -> Optional[int]:
    """
    Estimate remaining training time, adapting the strategy to the job phase.

    Strategy:
    - Early stage (progress <= 20%): project from recent per-product times
      (or a 60s/product default) plus a fixed one-minute overhead.
    - Mid/late stage: use actual observed throughput
      (elapsed_time / products_completed).
    - Fallback: linear extrapolation from the overall progress percentage.

    Args:
        progress: Current progress percentage (0-100)
        elapsed_time: Time elapsed since job start (seconds)
        products_completed: Number of products completed
        total_products: Total number of products
        recent_product_times: List of recent product training times (seconds)
        max_remaining_seconds: Maximum remaining time (safety cap)

    Returns:
        Estimated remaining time in seconds, or None if can't calculate
    """
    # Nothing to estimate before the job has started or after it finished.
    if progress <= 0 or progress >= 100:
        return None

    products_left = total_products - products_completed

    if progress <= 20:
        # Early stage: measured throughput is not yet reliable.
        if recent_product_times:
            per_product = sum(recent_product_times) / len(recent_product_times)
        else:
            per_product = 60.0  # default assumption: 1 minute per product
        remaining = (products_left * per_product) + 60.0  # +1 min overhead
        logger.debug(
            "Early stage estimation",
            progress=progress,
            remaining_products=products_left,
            avg_time_per_product=per_product,
            estimated_remaining=remaining
        )
    elif products_completed > 0:
        # Mid/late stage: extrapolate from the actual pace of this run.
        per_product = elapsed_time / products_completed
        remaining = products_left * per_product
        logger.debug(
            "Mid/late stage estimation",
            progress=progress,
            products_completed=products_completed,
            total_products=total_products,
            actual_time_per_product=per_product,
            estimated_remaining=remaining
        )
    else:
        # No per-product data available: linear extrapolation on progress %.
        remaining = (elapsed_time / progress) * 100 - elapsed_time
        logger.debug(
            "Fallback linear estimation",
            progress=progress,
            elapsed_time=elapsed_time,
            estimated_remaining=remaining
        )

    # Never report more than the safety cap.
    return int(min(remaining, max_remaining_seconds))
def calculate_average_product_time(
    products_completed: int,
    elapsed_time: float,
    min_products_threshold: int = 3
) -> Optional[float]:
    """
    Derive the average time per product from the current job's progress.

    Args:
        products_completed: Number of products completed
        elapsed_time: Time elapsed since job start (seconds)
        min_products_threshold: Minimum products needed for reliable calculation

    Returns:
        Average time per product in seconds, or None if insufficient data
    """
    # Too few samples would make the average pure noise.
    if products_completed < min_products_threshold:
        return None
    average = elapsed_time / products_completed
    logger.debug(
        "Calculated average product time",
        products_completed=products_completed,
        elapsed_time=elapsed_time,
        avg_time=average
    )
    return average
def format_time_remaining(seconds: int) -> str:
    """
    Format remaining time in human-readable format.

    Args:
        seconds: Time in seconds

    Returns:
        Formatted string (e.g., "5 minutes", "1 hour 23 minutes")

    Examples:
        >>> format_time_remaining(45)
        "45 seconds"
        >>> format_time_remaining(180)
        "3 minutes"
        >>> format_time_remaining(60)
        "1 minute"
        >>> format_time_remaining(5400)
        "1 hour 30 minutes"
    """
    def _unit(value: int, unit: str) -> str:
        # Consistent pluralization ("1 minute" vs "2 minutes"). The original
        # only pluralized hours and produced "1 minutes" / "1 seconds".
        return f"{value} {unit}{'s' if value != 1 else ''}"

    if seconds < 60:
        return _unit(seconds, "second")
    minutes, remaining_seconds = divmod(seconds, 60)
    if minutes < 60:
        if remaining_seconds > 0:
            return f"{_unit(minutes, 'minute')} {_unit(remaining_seconds, 'second')}"
        return _unit(minutes, "minute")
    hours, remaining_minutes = divmod(minutes, 60)
    # Note: leftover seconds are intentionally dropped at hour granularity,
    # matching the original behavior.
    if remaining_minutes > 0:
        return f"{_unit(hours, 'hour')} {_unit(remaining_minutes, 'minute')}"
    return _unit(hours, "hour")
def get_historical_average_estimate(
    db_session,
    tenant_id: str,
    lookback_days: int = 30,
    limit: int = 10
) -> Optional[float]:
    """
    Get historical average training time per product for a tenant.

    Queries the TrainingPerformanceMetrics table for records completed within
    the lookback window and combines them into a recency-weighted average.

    Args:
        db_session: Database session
        tenant_id: Tenant UUID
        lookback_days: How many days back to look
        limit: Maximum number of historical records to consider

    Returns:
        Average time per product in seconds, or None if no historical data
    """
    try:
        from app.models.training import TrainingPerformanceMetrics

        cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
        # Most recent runs first, restricted to the lookback window.
        recent = (
            db_session.query(TrainingPerformanceMetrics)
            .filter(
                TrainingPerformanceMetrics.tenant_id == tenant_id,
                TrainingPerformanceMetrics.completed_at >= cutoff
            )
            .order_by(TrainingPerformanceMetrics.completed_at.desc())
            .limit(limit)
            .all()
        )
        if not recent:
            logger.info(
                "No historical training data found",
                tenant_id=tenant_id,
                lookback_days=lookback_days
            )
            return None

        # Recency-weighted mean: the newest record gets weight `limit`,
        # each older record one less.
        weighted_sum = 0
        total_weight = 0
        for index, record in enumerate(recent):
            weight = limit - index
            weighted_sum += record.avg_time_per_product * weight
            total_weight += weight
        if total_weight == 0:
            return None
        weighted_avg = weighted_sum / total_weight
        logger.info(
            "Calculated historical average",
            tenant_id=tenant_id,
            records_used=len(recent),
            weighted_avg=weighted_avg
        )
        return weighted_avg
    except Exception as e:
        # Best-effort lookup: any failure is treated as "no historical data".
        logger.error(
            "Error getting historical average",
            tenant_id=tenant_id,
            error=str(e)
        )
        return None