Initial commit - production deployment

commit c23d00dd92 (2026-01-21 17:17:16 +01:00)
2289 changed files with 638440 additions and 0 deletions


@@ -0,0 +1,92 @@
"""
Training Service Utilities
"""
from .ml_datetime import (
ensure_timezone_aware,
ensure_timezone_naive,
normalize_datetime_to_utc,
normalize_dataframe_datetime_column,
prepare_prophet_datetime,
safe_datetime_comparison,
get_current_utc,
convert_timestamp_to_datetime
)
from .circuit_breaker import (
CircuitBreaker,
CircuitBreakerError,
CircuitState,
circuit_breaker_registry
)
from .file_utils import (
calculate_file_checksum,
verify_file_checksum,
get_file_size,
ensure_directory_exists,
safe_file_delete,
get_file_metadata,
ChecksummedFile
)
from .distributed_lock import (
DatabaseLock,
SimpleDatabaseLock,
LockAcquisitionError,
get_training_lock
)
from .retry import (
RetryStrategy,
RetryError,
retry_async,
with_retry,
retry_with_timeout,
AdaptiveRetryStrategy,
TimeoutRetryStrategy,
HTTP_RETRY_STRATEGY,
DATABASE_RETRY_STRATEGY,
EXTERNAL_SERVICE_RETRY_STRATEGY
)
__all__ = [
# Timezone utilities
'ensure_timezone_aware',
'ensure_timezone_naive',
'normalize_datetime_to_utc',
'normalize_dataframe_datetime_column',
'prepare_prophet_datetime',
'safe_datetime_comparison',
'get_current_utc',
'convert_timestamp_to_datetime',
# Circuit breaker
'CircuitBreaker',
'CircuitBreakerError',
'CircuitState',
'circuit_breaker_registry',
# File utilities
'calculate_file_checksum',
'verify_file_checksum',
'get_file_size',
'ensure_directory_exists',
'safe_file_delete',
'get_file_metadata',
'ChecksummedFile',
# Distributed locking
'DatabaseLock',
'SimpleDatabaseLock',
'LockAcquisitionError',
'get_training_lock',
# Retry mechanisms
'RetryStrategy',
'RetryError',
'retry_async',
'with_retry',
'retry_with_timeout',
'AdaptiveRetryStrategy',
'TimeoutRetryStrategy',
'HTTP_RETRY_STRATEGY',
'DATABASE_RETRY_STRATEGY',
'EXTERNAL_SERVICE_RETRY_STRATEGY'
]
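
For orientation, a minimal consumption sketch of the exports above. The package import path (`utils`) is an assumption; only the exported names come from the file itself.

# Hypothetical caller; adjust the import path to wherever this package lives.
from utils import calculate_file_checksum, ensure_directory_exists, get_current_utc

artifacts_dir = ensure_directory_exists("/tmp/artifacts")      # Path, created if missing
print("started at", get_current_utc().isoformat())
print("checksum of this script:", calculate_file_checksum(__file__))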


@@ -0,0 +1,198 @@
"""
Circuit Breaker Pattern Implementation
Protects against cascading failures from external service calls
"""
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Optional
import logging
from functools import wraps
logger = logging.getLogger(__name__)
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Circuit is open, rejecting requests
HALF_OPEN = "half_open" # Testing if service recovered
class CircuitBreakerError(Exception):
"""Raised when circuit breaker is open"""
pass
class CircuitBreaker:
"""
Circuit breaker to prevent cascading failures.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Too many failures, rejecting all requests
- HALF_OPEN: Testing if service recovered, allowing limited requests
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception,
name: str = "circuit_breaker"
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before attempting recovery
expected_exception: Exception type to catch (others will pass through)
name: Name for logging purposes
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.name = name
self.failure_count = 0
self.last_failure_time: Optional[float] = None
self.state = CircuitState.CLOSED
def _record_success(self):
"""Record successful call"""
self.failure_count = 0
self.last_failure_time = None
if self.state == CircuitState.HALF_OPEN:
logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
self.state = CircuitState.CLOSED
def _record_failure(self):
"""Record failed call"""
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold:
if self.state != CircuitState.OPEN:
logger.warning(
f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
)
self.state = CircuitState.OPEN
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset circuit"""
return (
self.state == CircuitState.OPEN
and self.last_failure_time is not None
and time.time() - self.last_failure_time >= self.recovery_timeout
)
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Async function to execute
*args: Positional arguments for func
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
CircuitBreakerError: If circuit is open
Exception: Original exception if not expected_exception type
"""
# Check if circuit is open
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
self.state = CircuitState.HALF_OPEN
else:
raise CircuitBreakerError(
f"Circuit breaker '{self.name}' is open. "
f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
)
try:
# Execute the function
result = await func(*args, **kwargs)
self._record_success()
return result
except self.expected_exception as e:
self._record_failure()
            logger.error(
                f"Circuit breaker '{self.name}' caught failure: {e} "
                f"(failure_count={self.failure_count}, state={self.state.value})"
            )
raise
def __call__(self, func: Callable) -> Callable:
"""Decorator interface for circuit breaker"""
@wraps(func)
async def wrapper(*args, **kwargs):
return await self.call(func, *args, **kwargs)
return wrapper
def get_state(self) -> dict:
"""Get current circuit breaker state for monitoring"""
return {
"name": self.name,
"state": self.state.value,
"failure_count": self.failure_count,
"failure_threshold": self.failure_threshold,
"last_failure_time": self.last_failure_time,
"recovery_timeout": self.recovery_timeout
}
class CircuitBreakerRegistry:
"""Registry to manage multiple circuit breakers"""
def __init__(self):
self._breakers: dict[str, CircuitBreaker] = {}
def get_or_create(
self,
name: str,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception
) -> CircuitBreaker:
"""Get existing circuit breaker or create new one"""
if name not in self._breakers:
self._breakers[name] = CircuitBreaker(
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
expected_exception=expected_exception,
name=name
)
return self._breakers[name]
def get(self, name: str) -> Optional[CircuitBreaker]:
"""Get circuit breaker by name"""
return self._breakers.get(name)
def get_all_states(self) -> dict:
"""Get states of all circuit breakers"""
return {
name: breaker.get_state()
for name, breaker in self._breakers.items()
}
def reset(self, name: str):
"""Manually reset a circuit breaker"""
if name in self._breakers:
breaker = self._breakers[name]
breaker.failure_count = 0
breaker.last_failure_time = None
breaker.state = CircuitState.CLOSED
logger.info(f"Circuit breaker '{name}' manually reset")
# Global registry instance
circuit_breaker_registry = CircuitBreakerRegistry()
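
A minimal usage sketch for the breaker above, assuming the module is importable as `circuit_breaker`; the "forecast-api" name and the always-failing call are illustrative only.

import asyncio

from circuit_breaker import CircuitBreakerError, circuit_breaker_registry

# One breaker per logical dependency, shared by name within the process.
breaker = circuit_breaker_registry.get_or_create(
    "forecast-api",
    failure_threshold=3,
    recovery_timeout=30.0,
    expected_exception=ConnectionError,
)

async def flaky_call() -> str:
    # Stand-in for an external call that keeps failing in this sketch.
    raise ConnectionError("upstream unavailable")

async def main() -> None:
    for _ in range(5):
        try:
            await breaker.call(flaky_call)
        except ConnectionError:
            pass                      # counted as a failure by the breaker
        except CircuitBreakerError:
            print("circuit open, failing fast")
    print(breaker.get_state())        # state/failure_count for monitoring

asyncio.run(main())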


@@ -0,0 +1,250 @@
"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product
HORIZONTAL SCALING FIX:
- Uses SHA256 for stable hash across all Python processes/pods
- Python's built-in hash() varies between processes due to hash randomization (Python 3.3+)
- This ensures all pods compute the same lock ID for the same lock name
"""
import asyncio
import time
import hashlib
from typing import Optional, Union
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
def _hash_lock_name(self, name: str) -> int:
"""
Convert lock name to integer ID for PostgreSQL advisory lock.
CRITICAL: Uses SHA256 for stable hash across all Python processes/pods.
Python's built-in hash() varies between processes due to hash randomization
(PYTHONHASHSEED, enabled by default since Python 3.3), which would cause
different pods to compute different lock IDs for the same lock name,
defeating the purpose of distributed locking.
"""
# Use SHA256 for stable, cross-process hash
hash_bytes = hashlib.sha256(name.encode('utf-8')).digest()
# Take first 4 bytes and convert to positive 31-bit integer
# (PostgreSQL advisory locks use bigint, but we use 31-bit for safety)
return int.from_bytes(hash_bytes[:4], 'big') % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "training-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
                    inserted = result.scalar()
                    await session.commit()
                    if inserted is not None:
                        acquired = True
logger.info(f"Acquired simple lock: {self.lock_name}")
break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_training_lock(tenant_id: str, product_id: str, use_advisory: bool = True) -> Union[DatabaseLock, SimpleDatabaseLock]:
"""
Get distributed lock for training a specific product.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"training:{tenant_id}:{product_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
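
A usage sketch for the training lock, assuming the module is importable as `distributed_lock`; the engine URL and session factory are placeholders, not part of the commit.

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from distributed_lock import LockAcquisitionError, get_training_lock

# Placeholder connection settings for the sketch.
engine = create_async_engine("postgresql+asyncpg://user:pass@db/training")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)

async def train_product(tenant_id: str, product_id: str) -> None:
    lock = get_training_lock(tenant_id, product_id)     # advisory lock by default
    async with SessionLocal() as session:
        try:
            async with lock.acquire(session):
                ...  # run the training job; only one pod enters per product
        except LockAcquisitionError:
            print(f"training already in progress for {product_id}, skipping")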


@@ -0,0 +1,216 @@
"""
File Utility Functions
Utilities for secure file operations including checksum verification
"""
import hashlib
import os
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)
def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
"""
Calculate checksum of a file.
Args:
file_path: Path to file
algorithm: Hash algorithm (sha256, md5, etc.)
Returns:
Hexadecimal checksum string
Raises:
FileNotFoundError: If file doesn't exist
ValueError: If algorithm not supported
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
hash_func = hashlib.new(algorithm)
except ValueError:
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
# Read file in chunks to handle large files efficiently
with open(file_path, 'rb') as f:
while chunk := f.read(8192):
hash_func.update(chunk)
return hash_func.hexdigest()
def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
"""
Verify file matches expected checksum.
Args:
file_path: Path to file
expected_checksum: Expected checksum value
algorithm: Hash algorithm used
Returns:
True if checksum matches, False otherwise
"""
try:
actual_checksum = calculate_file_checksum(file_path, algorithm)
matches = actual_checksum == expected_checksum
if matches:
logger.debug(f"Checksum verified for {file_path}")
else:
            logger.warning(
                f"Checksum mismatch for {file_path}: "
                f"expected {expected_checksum}, actual {actual_checksum}"
            )
return matches
except Exception as e:
logger.error(f"Error verifying checksum for {file_path}: {e}")
return False
def get_file_size(file_path: str) -> int:
"""
Get file size in bytes.
Args:
file_path: Path to file
Returns:
File size in bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return os.path.getsize(file_path)
def ensure_directory_exists(directory: str) -> Path:
"""
Ensure directory exists, create if necessary.
Args:
directory: Directory path
Returns:
Path object for directory
"""
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
return path
def safe_file_delete(file_path: str) -> bool:
"""
Safely delete a file, logging any errors.
Args:
file_path: Path to file
Returns:
True if deleted successfully, False otherwise
"""
try:
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Deleted file: {file_path}")
return True
else:
logger.warning(f"File not found for deletion: {file_path}")
return False
except Exception as e:
logger.error(f"Error deleting file {file_path}: {e}")
return False
def get_file_metadata(file_path: str) -> dict:
"""
Get comprehensive file metadata.
Args:
file_path: Path to file
Returns:
Dictionary with file metadata
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
stat = os.stat(file_path)
return {
"path": file_path,
"size_bytes": stat.st_size,
"created_at": stat.st_ctime,
"modified_at": stat.st_mtime,
"accessed_at": stat.st_atime,
"is_file": os.path.isfile(file_path),
"is_dir": os.path.isdir(file_path),
"exists": True
}
class ChecksummedFile:
"""
Context manager for working with checksummed files.
Automatically calculates and stores checksum when file is written.
"""
def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
"""
Initialize checksummed file handler.
Args:
file_path: Path to the file
checksum_path: Path to store checksum (default: file_path + '.checksum')
algorithm: Hash algorithm to use
"""
self.file_path = file_path
self.checksum_path = checksum_path or f"{file_path}.checksum"
self.algorithm = algorithm
self.checksum: Optional[str] = None
def calculate_and_save_checksum(self) -> str:
"""Calculate checksum and save to file"""
self.checksum = calculate_file_checksum(self.file_path, self.algorithm)
with open(self.checksum_path, 'w') as f:
f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")
logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
return self.checksum
def load_and_verify_checksum(self) -> bool:
"""Load expected checksum and verify file"""
try:
with open(self.checksum_path, 'r') as f:
expected_checksum = f.read().strip().split()[0]
return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)
except FileNotFoundError:
logger.warning(f"Checksum file not found: {self.checksum_path}")
return False
except Exception as e:
logger.error(f"Error loading checksum: {e}")
return False
def get_stored_checksum(self) -> Optional[str]:
"""Get checksum from stored file"""
try:
with open(self.checksum_path, 'r') as f:
return f.read().strip().split()[0]
except FileNotFoundError:
return None
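
A sketch of the write-then-verify flow with ChecksummedFile, assuming the module is importable as `file_utils`; the paths and payload are illustrative.

from file_utils import ChecksummedFile, ensure_directory_exists

model_dir = ensure_directory_exists("/tmp/models")
model_path = str(model_dir / "demand_model.pkl")

# Write the artifact, then persist its checksum next to it (model_path + ".checksum").
with open(model_path, "wb") as f:
    f.write(b"serialized model bytes")
checksummed = ChecksummedFile(model_path)
checksummed.calculate_and_save_checksum()

# Later, verify integrity before loading the artifact.
if not checksummed.load_and_verify_checksum():
    raise RuntimeError(f"checksum mismatch for {model_path}")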


@@ -0,0 +1,270 @@
"""
ML-Specific DateTime Utilities
DateTime utilities for machine learning operations, specifically for:
- Prophet forecasting model (requires timezone-naive datetimes)
- Pandas DataFrame datetime operations
- Time series data processing
"""
from datetime import datetime, timezone
from typing import Optional, Union
import pandas as pd
import logging
logger = logging.getLogger(__name__)
def ensure_timezone_aware(dt: Optional[datetime], default_tz=timezone.utc) -> Optional[datetime]:
"""
Ensure a datetime is timezone-aware.
Args:
dt: Datetime to check
default_tz: Timezone to apply if datetime is naive (default: UTC)
Returns:
Timezone-aware datetime
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=default_tz)
return dt
def ensure_timezone_naive(dt: Optional[datetime]) -> Optional[datetime]:
"""
Remove timezone information from a datetime.
Args:
dt: Datetime to process
Returns:
Timezone-naive datetime
"""
if dt is None:
return None
if dt.tzinfo is not None:
return dt.replace(tzinfo=None)
return dt
def normalize_datetime_to_utc(dt: Union[datetime, pd.Timestamp, None]) -> Optional[datetime]:
"""
Normalize any datetime to UTC timezone-aware datetime.
Args:
dt: Datetime or pandas Timestamp to normalize
Returns:
UTC timezone-aware datetime
"""
if dt is None:
return None
if isinstance(dt, pd.Timestamp):
dt = dt.to_pydatetime()
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def normalize_dataframe_datetime_column(
df: pd.DataFrame,
column: str,
target_format: str = 'naive'
) -> pd.DataFrame:
"""
Normalize a datetime column in a dataframe to consistent format.
Args:
df: DataFrame to process
column: Name of datetime column
target_format: 'naive' or 'aware' (UTC)
Returns:
DataFrame with normalized datetime column
"""
if column not in df.columns:
logger.warning(f"Column {column} not found in dataframe")
return df
df[column] = pd.to_datetime(df[column])
if target_format == 'naive':
if df[column].dt.tz is not None:
df[column] = df[column].dt.tz_localize(None)
elif target_format == 'aware':
if df[column].dt.tz is None:
df[column] = df[column].dt.tz_localize(timezone.utc)
else:
df[column] = df[column].dt.tz_convert(timezone.utc)
else:
raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")
return df
def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
"""
Prepare datetime column for Prophet (requires timezone-naive datetimes).
Args:
df: DataFrame with datetime column
datetime_col: Name of datetime column (default: 'ds')
Returns:
DataFrame with Prophet-compatible datetime column
"""
df = df.copy()
df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
return df
def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
"""
Safely compare two datetimes, handling timezone mismatches.
Args:
dt1: First datetime
dt2: Second datetime
Returns:
-1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
"""
dt1_utc = normalize_datetime_to_utc(dt1)
dt2_utc = normalize_datetime_to_utc(dt2)
if dt1_utc < dt2_utc:
return -1
elif dt1_utc > dt2_utc:
return 1
else:
return 0
def get_current_utc() -> datetime:
"""
Get current datetime in UTC with timezone awareness.
Returns:
Current UTC datetime
"""
return datetime.now(timezone.utc)
def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
"""
Convert various timestamp formats to datetime.
Args:
timestamp: Unix timestamp (seconds or milliseconds) or ISO string
Returns:
UTC timezone-aware datetime
"""
if isinstance(timestamp, str):
dt = pd.to_datetime(timestamp)
return normalize_datetime_to_utc(dt)
if timestamp > 1e10:
timestamp = timestamp / 1000
dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
return dt
def align_dataframe_dates(
dfs: list[pd.DataFrame],
date_column: str = 'ds',
method: str = 'inner'
) -> list[pd.DataFrame]:
"""
Align multiple dataframes to have the same date range.
Args:
dfs: List of DataFrames to align
date_column: Name of the date column
method: 'inner' (intersection) or 'outer' (union)
Returns:
List of aligned DataFrames
"""
if not dfs:
return []
if len(dfs) == 1:
return dfs
all_dates = None
for df in dfs:
if date_column not in df.columns:
continue
dates = set(pd.to_datetime(df[date_column]).dt.date)
if all_dates is None:
all_dates = dates
else:
if method == 'inner':
all_dates = all_dates.intersection(dates)
elif method == 'outer':
all_dates = all_dates.union(dates)
aligned_dfs = []
for df in dfs:
if date_column not in df.columns:
aligned_dfs.append(df)
continue
df = df.copy()
df[date_column] = pd.to_datetime(df[date_column])
df['_date_only'] = df[date_column].dt.date
df = df[df['_date_only'].isin(all_dates)]
df = df.drop('_date_only', axis=1)
aligned_dfs.append(df)
return aligned_dfs
def fill_missing_dates(
df: pd.DataFrame,
date_column: str = 'ds',
freq: str = 'D',
fill_value: float = 0.0
) -> pd.DataFrame:
"""
Fill missing dates in a DataFrame with a specified frequency.
Args:
df: DataFrame with date column
date_column: Name of the date column
freq: Pandas frequency string ('D' for daily, 'H' for hourly, etc.)
fill_value: Value to fill for missing dates
Returns:
DataFrame with filled dates
"""
df = df.copy()
df[date_column] = pd.to_datetime(df[date_column])
df = df.set_index(date_column)
full_range = pd.date_range(
start=df.index.min(),
end=df.index.max(),
freq=freq
)
df = df.reindex(full_range, fill_value=fill_value)
df = df.reset_index()
df = df.rename(columns={'index': date_column})
return df
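
A small end-to-end sketch for the Prophet preparation helpers, assuming the module is importable as `ml_datetime`; the two-row frame is made up for illustration.

import pandas as pd

from ml_datetime import fill_missing_dates, prepare_prophet_datetime

raw = pd.DataFrame({
    "ds": ["2024-01-01T00:00:00+01:00", "2024-01-03T00:00:00+01:00"],
    "y": [12.0, 7.0],
})

daily = fill_missing_dates(raw, date_column="ds", freq="D", fill_value=0.0)
prophet_ready = prepare_prophet_datetime(daily)   # 'ds' is now tz-naive, Jan 2 filled with 0.0
print(prophet_ready)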


@@ -0,0 +1,316 @@
"""
Retry Mechanism with Exponential Backoff
Handles transient failures with intelligent retry strategies
"""
import asyncio
import time
import random
from typing import Callable, Any, Optional, Type, Tuple
from functools import wraps
import logging
logger = logging.getLogger(__name__)
class RetryError(Exception):
"""Raised when all retry attempts are exhausted"""
def __init__(self, message: str, attempts: int, last_exception: Exception):
super().__init__(message)
self.attempts = attempts
self.last_exception = last_exception
class RetryStrategy:
"""Base retry strategy"""
def __init__(
self,
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Initialize retry strategy.
Args:
max_attempts: Maximum number of retry attempts
initial_delay: Initial delay in seconds
max_delay: Maximum delay between retries
exponential_base: Base for exponential backoff
jitter: Add random jitter to prevent thundering herd
retriable_exceptions: Tuple of exception types to retry
"""
self.max_attempts = max_attempts
self.initial_delay = initial_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.jitter = jitter
self.retriable_exceptions = retriable_exceptions
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt using exponential backoff"""
delay = min(
self.initial_delay * (self.exponential_base ** attempt),
self.max_delay
)
if self.jitter:
            # Add random jitter (50-100% of the computed delay)
delay = delay * (0.5 + random.random() * 0.5)
return delay
def is_retriable(self, exception: Exception) -> bool:
"""Check if exception should trigger retry"""
return isinstance(exception, self.retriable_exceptions)
async def retry_async(
func: Callable,
*args,
strategy: Optional[RetryStrategy] = None,
**kwargs
) -> Any:
"""
Retry async function with exponential backoff.
Args:
func: Async function to retry
*args: Positional arguments for func
strategy: Retry strategy (uses default if None)
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
RetryError: When all attempts exhausted
"""
if strategy is None:
strategy = RetryStrategy()
last_exception = None
for attempt in range(strategy.max_attempts):
try:
result = await func(*args, **kwargs)
if attempt > 0:
                logger.info(
                    f"Retry of {func.__name__} succeeded on attempt {attempt + 1}"
                )
return result
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
                logger.error(
                    f"Non-retriable exception in {func.__name__}: {e}"
                )
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
                logger.warning(
                    f"Attempt {attempt + 1}/{strategy.max_attempts} of {func.__name__} "
                    f"failed with {e!r}, retrying in {delay:.2f}s"
                )
await asyncio.sleep(delay)
else:
                logger.error(
                    f"All {strategy.max_attempts} retry attempts of {func.__name__} exhausted: {e}"
                )
raise RetryError(
f"Failed after {strategy.max_attempts} attempts: {str(last_exception)}",
attempts=strategy.max_attempts,
last_exception=last_exception
)
def with_retry(
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Decorator to add retry logic to async functions.
Example:
@with_retry(max_attempts=5, initial_delay=2.0)
async def fetch_data():
# Your code here
pass
"""
strategy = RetryStrategy(
max_attempts=max_attempts,
initial_delay=initial_delay,
max_delay=max_delay,
exponential_base=exponential_base,
jitter=jitter,
retriable_exceptions=retriable_exceptions
)
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
return await retry_async(func, *args, strategy=strategy, **kwargs)
return wrapper
return decorator
class AdaptiveRetryStrategy(RetryStrategy):
"""
Adaptive retry strategy that adjusts based on success/failure patterns.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.success_count = 0
self.failure_count = 0
self.consecutive_failures = 0
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay with adaptation based on recent history"""
base_delay = super().calculate_delay(attempt)
# Increase delay if seeing consecutive failures
if self.consecutive_failures > 5:
multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
base_delay *= multiplier
return min(base_delay, self.max_delay)
def record_success(self):
"""Record successful attempt"""
self.success_count += 1
self.consecutive_failures = 0
def record_failure(self):
"""Record failed attempt"""
self.failure_count += 1
self.consecutive_failures += 1
class TimeoutRetryStrategy(RetryStrategy):
"""
Retry strategy with overall timeout across all attempts.
"""
def __init__(self, *args, timeout: float = 300.0, **kwargs):
"""
Args:
timeout: Total timeout in seconds for all attempts
"""
super().__init__(*args, **kwargs)
self.timeout = timeout
self.start_time: Optional[float] = None
def should_retry(self, attempt: int) -> bool:
"""Check if should attempt another retry"""
if self.start_time is None:
self.start_time = time.time()
return True
elapsed = time.time() - self.start_time
return elapsed < self.timeout and attempt < self.max_attempts
async def retry_with_timeout(
func: Callable,
*args,
max_attempts: int = 3,
timeout: float = 300.0,
**kwargs
) -> Any:
"""
Retry with overall timeout.
Args:
func: Function to retry
max_attempts: Maximum attempts
timeout: Overall timeout in seconds
Returns:
Result from func
"""
strategy = TimeoutRetryStrategy(
max_attempts=max_attempts,
timeout=timeout
)
start_time = time.time()
strategy.start_time = start_time
last_exception = None
for attempt in range(strategy.max_attempts):
if time.time() - start_time >= timeout:
raise RetryError(
f"Timeout of {timeout}s exceeded",
attempts=attempt + 1,
last_exception=last_exception
)
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
await asyncio.sleep(delay)
raise RetryError(
f"Failed after {strategy.max_attempts} attempts",
attempts=strategy.max_attempts,
last_exception=last_exception
)
# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
max_attempts=3,
initial_delay=1.0,
max_delay=10.0,
exponential_base=2.0,
jitter=True
)
DATABASE_RETRY_STRATEGY = RetryStrategy(
max_attempts=5,
initial_delay=0.5,
max_delay=5.0,
exponential_base=1.5,
jitter=True
)
EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
max_attempts=4,
initial_delay=2.0,
max_delay=30.0,
exponential_base=2.5,
jitter=True
)
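
A usage sketch for the retry decorator above, assuming the module is importable as `retry`; the flaky feature-store call is invented for illustration.

import asyncio
import random

from retry import RetryError, with_retry

@with_retry(max_attempts=4, initial_delay=0.2, retriable_exceptions=(TimeoutError,))
async def fetch_features() -> str:
    if random.random() < 0.7:
        raise TimeoutError("feature store slow")    # retried with exponential backoff + jitter
    return "features"

async def main() -> None:
    try:
        print(await fetch_features())
    except RetryError as exc:
        print(f"gave up after {exc.attempts} attempts: {exc.last_exception}")

asyncio.run(main())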


@@ -0,0 +1,340 @@
"""
Training Time Estimation Utilities
Provides intelligent time estimation for training jobs based on:
- Product count
- Historical performance data
- Current progress and throughput
"""
from typing import List, Optional
from datetime import datetime, timedelta, timezone
import structlog
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
logger = structlog.get_logger()
def calculate_initial_estimate(
total_products: int,
avg_training_time_per_product: float = 60.0, # seconds, default 1 min/product
data_analysis_overhead: float = 120.0, # seconds, data loading & analysis
finalization_overhead: float = 60.0, # seconds, saving models & cleanup
min_estimate_minutes: int = 5,
max_estimate_minutes: int = 60
) -> int:
"""
Calculate realistic initial time estimate for training job.
Formula:
total_time = data_analysis + (products * avg_time_per_product) + finalization
Args:
total_products: Number of products to train
avg_training_time_per_product: Average time per product in seconds
data_analysis_overhead: Time for data loading and analysis in seconds
finalization_overhead: Time for saving models and cleanup in seconds
min_estimate_minutes: Minimum estimate (prevents unrealistic low values)
max_estimate_minutes: Maximum estimate (prevents unrealistic high values)
Returns:
Estimated duration in minutes
Examples:
        >>> calculate_initial_estimate(1)
        5  # 120 + 60 + 60 = 240s = 4min, raised to the 5-minute floor
>>> calculate_initial_estimate(5)
8 # 120 + 300 + 60 = 480s = 8min
>>> calculate_initial_estimate(10)
13 # 120 + 600 + 60 = 780s = 13min
>>> calculate_initial_estimate(20)
23 # 120 + 1200 + 60 = 1380s = 23min
>>> calculate_initial_estimate(100)
60 # Capped at max (would be 103 min)
"""
# Calculate total estimated time in seconds
estimated_seconds = (
data_analysis_overhead +
(total_products * avg_training_time_per_product) +
finalization_overhead
)
    # Convert to minutes, rounding to the nearest minute
estimated_minutes = int((estimated_seconds / 60) + 0.5)
# Apply min/max bounds
estimated_minutes = max(min_estimate_minutes, min(max_estimate_minutes, estimated_minutes))
logger.info(
"Calculated initial time estimate",
total_products=total_products,
estimated_seconds=estimated_seconds,
estimated_minutes=estimated_minutes,
avg_time_per_product=avg_training_time_per_product
)
return estimated_minutes
def calculate_estimated_completion_time(
estimated_duration_minutes: int,
start_time: Optional[datetime] = None
) -> datetime:
"""
Calculate estimated completion timestamp.
Args:
estimated_duration_minutes: Estimated duration in minutes
start_time: Job start time (defaults to now)
Returns:
Estimated completion datetime (timezone-aware UTC)
"""
if start_time is None:
start_time = datetime.now(timezone.utc)
completion_time = start_time + timedelta(minutes=estimated_duration_minutes)
return completion_time
def calculate_remaining_time_smart(
progress: int,
elapsed_time: float,
products_completed: int,
total_products: int,
recent_product_times: Optional[List[float]] = None,
max_remaining_seconds: int = 1800 # 30 minutes
) -> Optional[int]:
"""
Calculate remaining time using smart algorithm that considers:
- Current progress percentage
- Actual throughput (products completed / elapsed time)
- Recent performance (weighted moving average)
Args:
progress: Current progress percentage (0-100)
elapsed_time: Time elapsed since job start (seconds)
products_completed: Number of products completed
total_products: Total number of products
recent_product_times: List of recent product training times (seconds)
max_remaining_seconds: Maximum remaining time (safety cap)
Returns:
Estimated remaining time in seconds, or None if can't calculate
"""
# Job completed or not started
if progress >= 100 or progress <= 0:
return None
# Early stage (0-20%): Use weighted estimate
if progress <= 20:
# In data analysis phase - estimate based on remaining products
remaining_products = total_products - products_completed
if recent_product_times and len(recent_product_times) > 0:
# Use recent performance if available
avg_time_per_product = sum(recent_product_times) / len(recent_product_times)
else:
# Fallback to default
avg_time_per_product = 60.0 # 1 minute per product
# Estimate: remaining products * avg time + overhead
estimated_remaining = (remaining_products * avg_time_per_product) + 60.0 # +1 min overhead
logger.debug(
"Early stage estimation",
progress=progress,
remaining_products=remaining_products,
avg_time_per_product=avg_time_per_product,
estimated_remaining=estimated_remaining
)
# Mid/late stage (21-99%): Use actual throughput
else:
if products_completed > 0:
# Calculate actual time per product from current run
actual_time_per_product = elapsed_time / products_completed
remaining_products = total_products - products_completed
estimated_remaining = remaining_products * actual_time_per_product
logger.debug(
"Mid/late stage estimation",
progress=progress,
products_completed=products_completed,
total_products=total_products,
actual_time_per_product=actual_time_per_product,
estimated_remaining=estimated_remaining
)
else:
# Fallback to linear extrapolation
estimated_total = (elapsed_time / progress) * 100
estimated_remaining = estimated_total - elapsed_time
logger.debug(
"Fallback linear estimation",
progress=progress,
elapsed_time=elapsed_time,
estimated_remaining=estimated_remaining
)
# Apply safety cap
estimated_remaining = min(estimated_remaining, max_remaining_seconds)
return int(estimated_remaining)
def calculate_average_product_time(
products_completed: int,
elapsed_time: float,
min_products_threshold: int = 3
) -> Optional[float]:
"""
Calculate average time per product from current job progress.
Args:
products_completed: Number of products completed
elapsed_time: Time elapsed since job start (seconds)
min_products_threshold: Minimum products needed for reliable calculation
Returns:
Average time per product in seconds, or None if insufficient data
"""
if products_completed < min_products_threshold:
return None
avg_time = elapsed_time / products_completed
logger.debug(
"Calculated average product time",
products_completed=products_completed,
elapsed_time=elapsed_time,
avg_time=avg_time
)
return avg_time
def format_time_remaining(seconds: int) -> str:
"""
Format remaining time in human-readable format.
Args:
seconds: Time in seconds
Returns:
Formatted string (e.g., "5 minutes", "1 hour 23 minutes")
Examples:
>>> format_time_remaining(45)
"45 seconds"
>>> format_time_remaining(180)
"3 minutes"
>>> format_time_remaining(5400)
"1 hour 30 minutes"
"""
if seconds < 60:
return f"{seconds} seconds"
minutes = seconds // 60
remaining_seconds = seconds % 60
if minutes < 60:
if remaining_seconds > 0:
return f"{minutes} minutes {remaining_seconds} seconds"
return f"{minutes} minutes"
hours = minutes // 60
remaining_minutes = minutes % 60
if remaining_minutes > 0:
return f"{hours} hour{'s' if hours > 1 else ''} {remaining_minutes} minutes"
return f"{hours} hour{'s' if hours > 1 else ''}"
async def get_historical_average_estimate(
db_session: AsyncSession,
tenant_id: str,
lookback_days: int = 30,
limit: int = 10
) -> Optional[float]:
"""
Get historical average training time per product for a tenant.
This function queries the TrainingPerformanceMetrics table to get
recent historical data and calculate an average.
Args:
db_session: Async database session
tenant_id: Tenant UUID
lookback_days: How many days back to look
limit: Maximum number of historical records to consider
Returns:
Average time per product in seconds, or None if no historical data
"""
try:
        from app.models.training import TrainingPerformanceMetrics
cutoff = datetime.now(timezone.utc) - timedelta(days=lookback_days)
# Query recent training performance metrics using SQLAlchemy 2.0 async pattern
query = (
select(TrainingPerformanceMetrics)
.where(
TrainingPerformanceMetrics.tenant_id == tenant_id,
TrainingPerformanceMetrics.completed_at >= cutoff
)
.order_by(TrainingPerformanceMetrics.completed_at.desc())
.limit(limit)
)
result = await db_session.execute(query)
metrics = result.scalars().all()
if not metrics:
logger.info(
"No historical training data found",
tenant_id=tenant_id,
lookback_days=lookback_days
)
return None
# Calculate weighted average (more recent = higher weight)
total_weight = 0
weighted_sum = 0
for i, metric in enumerate(metrics):
# Weight: newer records get higher weight
weight = limit - i
weighted_sum += metric.avg_time_per_product * weight
total_weight += weight
if total_weight == 0:
return None
weighted_avg = weighted_sum / total_weight
logger.info(
"Calculated historical average",
tenant_id=tenant_id,
records_used=len(metrics),
weighted_avg=weighted_avg
)
return weighted_avg
except Exception as e:
logger.error(
"Error getting historical average",
tenant_id=tenant_id,
error=str(e)
)
return None
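
A usage sketch for the estimation helpers above; the module name `time_estimation` and the mid-run numbers are assumptions for illustration.

from time_estimation import (
    calculate_estimated_completion_time,
    calculate_initial_estimate,
    calculate_remaining_time_smart,
    format_time_remaining,
)

total_products = 12
estimate_min = calculate_initial_estimate(total_products)    # 120 + 720 + 60 s -> 15 min
eta = calculate_estimated_completion_time(estimate_min)
print(f"initial estimate: {estimate_min} min, ETA {eta.isoformat()}")

# Mid-run: 5 of 12 products done after 6 minutes of elapsed work.
remaining = calculate_remaining_time_smart(
    progress=45,
    elapsed_time=360.0,
    products_completed=5,
    total_products=total_products,
)
if remaining is not None:
    print(f"about {format_time_remaining(remaining)} remaining")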