""" Training Time Estimation Utilities Provides intelligent time estimation for training jobs based on: - Product count - Historical performance data - Current progress and throughput """ from typing import List, Optional from datetime import datetime, timedelta, timezone import structlog from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession logger = structlog.get_logger() def calculate_initial_estimate( total_products: int, avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product data_analysis_overhead: float = 120.0, # seconds, data loading & analysis finalization_overhead: float = 60.0, # seconds, saving models & cleanup min_estimate_minutes: int = 5, max_estimate_minutes: int = 60 ) -> int: """ Calculate realistic initial time estimate for training job. Formula: total_time = data_analysis + (products * avg_time_per_product) + finalization Args: total_products: Number of products to train avg_training_time_per_product: Average time per product in seconds data_analysis_overhead: Time for data loading and analysis in seconds finalization_overhead: Time for saving models and cleanup in seconds min_estimate_minutes: Minimum estimate (prevents unrealistic low values) max_estimate_minutes: Maximum estimate (prevents unrealistic high values) Returns: Estimated duration in minutes Examples: >>> calculate_initial_estimate(1) 4 # 120 + 60 + 60 = 240s = 4min >>> calculate_initial_estimate(5) 8 # 120 + 300 + 60 = 480s = 8min >>> calculate_initial_estimate(10) 13 # 120 + 600 + 60 = 780s = 13min >>> calculate_initial_estimate(20) 23 # 120 + 1200 + 60 = 1380s = 23min >>> calculate_initial_estimate(100) 60 # Capped at max (would be 103 min) """ # Calculate total estimated time in seconds estimated_seconds = ( data_analysis_overhead + (total_products * avg_training_time_per_product) + finalization_overhead ) # Convert to minutes, round up estimated_minutes = int((estimated_seconds / 60) + 0.5) # Apply min/max bounds estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes)) logger.info( "Calculated initial time estimate", total_products=total_products, estimated_seconds=estimated_seconds, estimated_minutes=estimated_minutes, avg_time_per_product=avg_training_time_per_product ) return estimated_minutes def calculate_estimated_completion_time( estimated_duration_minutes: int, start_time: Optional[datetime] = None ) -> datetime: """ Calculate estimated completion timestamp. 
def calculate_estimated_completion_time(
    estimated_duration_minutes: int,
    start_time: Optional[datetime] = None
) -> datetime:
    """
    Calculate the estimated completion timestamp.

    Args:
        estimated_duration_minutes: Estimated duration in minutes
        start_time: Job start time (defaults to now)

    Returns:
        Estimated completion datetime (timezone-aware UTC)
    """
    if start_time is None:
        start_time = datetime.now(timezone.utc)

    completion_time = start_time + timedelta(minutes=estimated_duration_minutes)
    return completion_time


def calculate_remaining_time_smart(
    progress: int,
    elapsed_time: float,
    products_completed: int,
    total_products: int,
    recent_product_times: Optional[List[float]] = None,
    max_remaining_seconds: int = 1800  # 30 minutes
) -> Optional[int]:
    """
    Calculate remaining time using a smart algorithm that considers:
    - Current progress percentage
    - Actual throughput (products completed / elapsed time)
    - Recent performance (weighted moving average)

    Args:
        progress: Current progress percentage (0-100)
        elapsed_time: Time elapsed since job start (seconds)
        products_completed: Number of products completed
        total_products: Total number of products
        recent_product_times: List of recent product training times (seconds)
        max_remaining_seconds: Maximum remaining time (safety cap)

    Returns:
        Estimated remaining time in seconds, or None if it can't be calculated
    """
    # Job completed or not started
    if progress >= 100 or progress <= 0:
        return None

    # Early stage (0-20%): use a weighted estimate
    if progress <= 20:
        # Still in the data analysis phase - estimate based on remaining products
        remaining_products = total_products - products_completed

        if recent_product_times:
            # Use recent performance if available
            avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
        else:
            # Fall back to the default
            avg_time_per_product = 60.0  # 1 minute per product

        # Estimate: remaining products * avg time + overhead
        estimated_remaining = (remaining_products * avg_time_per_product) + 60.0  # +1 min overhead

        logger.debug(
            "Early stage estimation",
            progress=progress,
            remaining_products=remaining_products,
            avg_time_per_product=avg_time_per_product,
            estimated_remaining=estimated_remaining
        )

    # Mid/late stage (21-99%): use actual throughput
    else:
        if products_completed > 0:
            # Calculate the actual time per product from the current run
            actual_time_per_product = elapsed_time / products_completed
            remaining_products = total_products - products_completed
            estimated_remaining = remaining_products * actual_time_per_product

            logger.debug(
                "Mid/late stage estimation",
                progress=progress,
                products_completed=products_completed,
                total_products=total_products,
                actual_time_per_product=actual_time_per_product,
                estimated_remaining=estimated_remaining
            )
        else:
            # Fall back to linear extrapolation from the progress percentage
            estimated_total = (elapsed_time / progress) * 100
            estimated_remaining = estimated_total - elapsed_time

            logger.debug(
                "Fallback linear estimation",
                progress=progress,
                elapsed_time=elapsed_time,
                estimated_remaining=estimated_remaining
            )

    # Apply safety cap
    estimated_remaining = min(estimated_remaining, max_remaining_seconds)

    return int(estimated_remaining)
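
# Usage sketch for a progress-polling endpoint (the surrounding handler and
# the literal values below are illustrative assumptions):
#
#   remaining = calculate_remaining_time_smart(
#       progress=40, elapsed_time=300.0,
#       products_completed=4, total_products=10,
#   )
#   # mid-stage path: 300s / 4 = 75s per product; 6 remaining * 75s = 450s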
def calculate_average_product_time(
    products_completed: int,
    elapsed_time: float,
    min_products_threshold: int = 3
) -> Optional[float]:
    """
    Calculate the average time per product from current job progress.

    Args:
        products_completed: Number of products completed
        elapsed_time: Time elapsed since job start (seconds)
        min_products_threshold: Minimum products needed for a reliable calculation

    Returns:
        Average time per product in seconds, or None if there is insufficient data
    """
    if products_completed < min_products_threshold:
        return None

    avg_time = elapsed_time / products_completed

    logger.debug(
        "Calculated average product time",
        products_completed=products_completed,
        elapsed_time=elapsed_time,
        avg_time=avg_time
    )

    return avg_time


def format_time_remaining(seconds: int) -> str:
    """
    Format remaining time in a human-readable way.

    Args:
        seconds: Time in seconds

    Returns:
        Formatted string (e.g., "5 minutes", "1 hour 23 minutes")

    Examples:
        >>> format_time_remaining(45)
        '45 seconds'
        >>> format_time_remaining(180)
        '3 minutes'
        >>> format_time_remaining(5400)
        '1 hour 30 minutes'
    """
    if seconds < 60:
        return f"{seconds} seconds"

    minutes = seconds // 60
    remaining_seconds = seconds % 60

    if minutes < 60:
        if remaining_seconds > 0:
            return f"{minutes} minutes {remaining_seconds} seconds"
        return f"{minutes} minutes"

    hours = minutes // 60
    remaining_minutes = minutes % 60

    if remaining_minutes > 0:
        return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minutes"
    return f"{hours} hour{'s' if hours > 1 else ''}"


async def get_historical_average_estimate(
    db_session: AsyncSession,
    tenant_id: str,
    lookback_days: int = 30,
    limit: int = 10
) -> Optional[float]:
    """
    Get the historical average training time per product for a tenant.

    Queries the TrainingPerformanceMetrics table for recent historical data
    and calculates a recency-weighted average.

    Args:
        db_session: Async database session
        tenant_id: Tenant UUID
        lookback_days: How many days back to look
        limit: Maximum number of historical records to consider

    Returns:
        Average time per product in seconds, or None if there is no historical data
    """
    try:
        from app.models.training import TrainingPerformanceMetrics

        cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)

        # Query recent training performance metrics using the SQLAlchemy 2.0 async pattern
        query = (
            select(TrainingPerformanceMetrics)
            .where(
                TrainingPerformanceMetrics.tenant_id == tenant_id,
                TrainingPerformanceMetrics.completed_at >= cutoff
            )
            .order_by(TrainingPerformanceMetrics.completed_at.desc())
            .limit(limit)
        )

        result = await db_session.execute(query)
        metrics = result.scalars().all()

        if not metrics:
            logger.info(
                "No historical training data found",
                tenant_id=tenant_id,
                lookback_days=lookback_days
            )
            return None

        # Calculate a weighted average (more recent records get a higher weight)
        total_weight = 0
        weighted_sum = 0
        for i, metric in enumerate(metrics):
            # Records are sorted newest-first, so earlier indices get higher weights
            weight = limit - i
            weighted_sum += metric.avg_time_per_product * weight
            total_weight += weight

        if total_weight == 0:
            return None

        weighted_avg = weighted_sum / total_weight

        logger.info(
            "Calculated historical average",
            tenant_id=tenant_id,
            records_used=len(metrics),
            weighted_avg=weighted_avg
        )

        return weighted_avg

    except Exception as e:
        logger.error(
            "Error getting historical average",
            tenant_id=tenant_id,
            error=str(e)
        )
        return None
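
# End-to-end sketch of how these helpers might compose in a job-creation flow.
# `db`, `tenant_id`, and `total_products` are assumed to come from the caller;
# none of them are defined in this module.
#
#   historical = await get_historical_average_estimate(db, tenant_id)
#   minutes = calculate_initial_estimate(
#       total_products,
#       avg_training_time_per_product=historical or 60.0,  # fall back to the default
#   )
#   eta = calculate_estimated_completion_time(minutes)
#   label = format_time_remaining(minutes * 60)  # e.g., "15 minutes"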