REFACTOR external service and improve websocket training

Urtzi Alfaro
2025-10-09 14:11:02 +02:00
parent 7c72f83c51
commit 3c689b4f98
111 changed files with 13289 additions and 2374 deletions

View File

@@ -0,0 +1,92 @@
"""
Training Service Utilities
"""
from .timezone_utils import (
ensure_timezone_aware,
ensure_timezone_naive,
normalize_datetime_to_utc,
normalize_dataframe_datetime_column,
prepare_prophet_datetime,
safe_datetime_comparison,
get_current_utc,
convert_timestamp_to_datetime
)
from .circuit_breaker import (
CircuitBreaker,
CircuitBreakerError,
CircuitState,
circuit_breaker_registry
)
from .file_utils import (
calculate_file_checksum,
verify_file_checksum,
get_file_size,
ensure_directory_exists,
safe_file_delete,
get_file_metadata,
ChecksummedFile
)
from .distributed_lock import (
DatabaseLock,
SimpleDatabaseLock,
LockAcquisitionError,
get_training_lock
)
from .retry import (
RetryStrategy,
RetryError,
retry_async,
with_retry,
retry_with_timeout,
AdaptiveRetryStrategy,
TimeoutRetryStrategy,
HTTP_RETRY_STRATEGY,
DATABASE_RETRY_STRATEGY,
EXTERNAL_SERVICE_RETRY_STRATEGY
)
__all__ = [
# Timezone utilities
'ensure_timezone_aware',
'ensure_timezone_naive',
'normalize_datetime_to_utc',
'normalize_dataframe_datetime_column',
'prepare_prophet_datetime',
'safe_datetime_comparison',
'get_current_utc',
'convert_timestamp_to_datetime',
# Circuit breaker
'CircuitBreaker',
'CircuitBreakerError',
'CircuitState',
'circuit_breaker_registry',
# File utilities
'calculate_file_checksum',
'verify_file_checksum',
'get_file_size',
'ensure_directory_exists',
'safe_file_delete',
'get_file_metadata',
'ChecksummedFile',
# Distributed locking
'DatabaseLock',
'SimpleDatabaseLock',
'LockAcquisitionError',
'get_training_lock',
# Retry mechanisms
'RetryStrategy',
'RetryError',
'retry_async',
'with_retry',
'retry_with_timeout',
'AdaptiveRetryStrategy',
'TimeoutRetryStrategy',
'HTTP_RETRY_STRATEGY',
'DATABASE_RETRY_STRATEGY',
'EXTERNAL_SERVICE_RETRY_STRATEGY'
]
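A minimal import sketch (not part of this diff); the package path training_service.utils is an assumption, since the real location of this package is not shown in the hunk:

# Hypothetical import path for the utils package introduced in this commit
from training_service.utils import (
    CircuitBreaker,
    get_training_lock,
    with_retry,
    get_current_utc,
)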

View File

@@ -0,0 +1,198 @@
"""
Circuit Breaker Pattern Implementation
Protects against cascading failures from external service calls
"""
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Optional
import logging
from functools import wraps
logger = logging.getLogger(__name__)
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Circuit is open, rejecting requests
HALF_OPEN = "half_open" # Testing if service recovered
class CircuitBreakerError(Exception):
"""Raised when circuit breaker is open"""
pass
class CircuitBreaker:
"""
Circuit breaker to prevent cascading failures.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Too many failures, rejecting all requests
- HALF_OPEN: Testing if service recovered, allowing limited requests
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception,
name: str = "circuit_breaker"
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Seconds to wait before attempting recovery
expected_exception: Exception type to catch (others will pass through)
name: Name for logging purposes
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.name = name
self.failure_count = 0
self.last_failure_time: Optional[float] = None
self.state = CircuitState.CLOSED
def _record_success(self):
"""Record successful call"""
self.failure_count = 0
self.last_failure_time = None
if self.state == CircuitState.HALF_OPEN:
logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
self.state = CircuitState.CLOSED
def _record_failure(self):
"""Record failed call"""
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.failure_threshold:
if self.state != CircuitState.OPEN:
logger.warning(
f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
)
self.state = CircuitState.OPEN
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset circuit"""
return (
self.state == CircuitState.OPEN
and self.last_failure_time is not None
and time.time() - self.last_failure_time >= self.recovery_timeout
)
async def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Async function to execute
*args: Positional arguments for func
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
CircuitBreakerError: If circuit is open
Exception: Original exception if not expected_exception type
"""
# Check if circuit is open
if self.state == CircuitState.OPEN:
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
self.state = CircuitState.HALF_OPEN
else:
raise CircuitBreakerError(
f"Circuit breaker '{self.name}' is open. "
f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
)
try:
# Execute the function
result = await func(*args, **kwargs)
self._record_success()
return result
        except self.expected_exception as e:
            self._record_failure()
            # Stdlib logging does not accept arbitrary keyword fields, so fold the
            # context into the message
            logger.error(
                f"Circuit breaker '{self.name}' caught failure: {e} "
                f"(failure_count={self.failure_count}, state={self.state.value})"
            )
            raise
def __call__(self, func: Callable) -> Callable:
"""Decorator interface for circuit breaker"""
@wraps(func)
async def wrapper(*args, **kwargs):
return await self.call(func, *args, **kwargs)
return wrapper
def get_state(self) -> dict:
"""Get current circuit breaker state for monitoring"""
return {
"name": self.name,
"state": self.state.value,
"failure_count": self.failure_count,
"failure_threshold": self.failure_threshold,
"last_failure_time": self.last_failure_time,
"recovery_timeout": self.recovery_timeout
}
class CircuitBreakerRegistry:
"""Registry to manage multiple circuit breakers"""
def __init__(self):
self._breakers: dict[str, CircuitBreaker] = {}
def get_or_create(
self,
name: str,
failure_threshold: int = 5,
recovery_timeout: float = 60.0,
expected_exception: type = Exception
) -> CircuitBreaker:
"""Get existing circuit breaker or create new one"""
if name not in self._breakers:
self._breakers[name] = CircuitBreaker(
failure_threshold=failure_threshold,
recovery_timeout=recovery_timeout,
expected_exception=expected_exception,
name=name
)
return self._breakers[name]
def get(self, name: str) -> Optional[CircuitBreaker]:
"""Get circuit breaker by name"""
return self._breakers.get(name)
def get_all_states(self) -> dict:
"""Get states of all circuit breakers"""
return {
name: breaker.get_state()
for name, breaker in self._breakers.items()
}
def reset(self, name: str):
"""Manually reset a circuit breaker"""
if name in self._breakers:
breaker = self._breakers[name]
breaker.failure_count = 0
breaker.last_failure_time = None
breaker.state = CircuitState.CLOSED
logger.info(f"Circuit breaker '{name}' manually reset")
# Global registry instance
circuit_breaker_registry = CircuitBreakerRegistry()
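A minimal usage sketch for the registry and breaker (illustration only, not part of this commit; the service name, URL, and use of httpx are assumptions):

import httpx
from typing import Optional

# Hypothetical external call protected by a registry-managed breaker
forecast_breaker = circuit_breaker_registry.get_or_create(
    name="forecast-api",                      # hypothetical service name
    failure_threshold=3,
    recovery_timeout=30.0,
    expected_exception=httpx.HTTPError
)

async def fetch_forecast(url: str) -> dict:
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.json()

async def get_forecast_safely(url: str) -> Optional[dict]:
    try:
        return await forecast_breaker.call(fetch_forecast, url)
    except CircuitBreakerError:
        # Circuit is open: fail fast instead of hammering the degraded upstream
        return None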

View File

@@ -0,0 +1,233 @@
"""
Distributed Locking Mechanisms
Prevents concurrent training jobs for the same product
"""
import asyncio
import hashlib
import time
from typing import Optional, Union
import logging
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timezone, timedelta
logger = logging.getLogger(__name__)
class LockAcquisitionError(Exception):
"""Raised when lock cannot be acquired"""
pass
class DatabaseLock:
"""
Database-based distributed lock using PostgreSQL advisory locks.
Works across multiple service instances.
"""
def __init__(self, lock_name: str, timeout: float = 30.0):
"""
Initialize database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
"""
self.lock_name = lock_name
self.timeout = timeout
self.lock_id = self._hash_lock_name(lock_name)
    def _hash_lock_name(self, name: str) -> int:
        """Convert lock name to a stable integer ID for PostgreSQL advisory lock"""
        # Built-in hash() is salted per process, which would give each service
        # instance a different lock id; use a deterministic digest instead and
        # fold it into a positive 31-bit integer
        digest = hashlib.sha256(name.encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "big") % (2**31)
@asynccontextmanager
async def acquire(self, session: AsyncSession):
"""
Acquire distributed lock as async context manager.
Args:
session: Database session for lock operations
Raises:
LockAcquisitionError: If lock cannot be acquired within timeout
"""
acquired = False
start_time = time.time()
try:
# Try to acquire lock with timeout
while time.time() - start_time < self.timeout:
# Try non-blocking lock acquisition
result = await session.execute(
text("SELECT pg_try_advisory_lock(:lock_id)"),
{"lock_id": self.lock_id}
)
acquired = result.scalar()
if acquired:
logger.info(f"Acquired lock: {self.lock_name} (id={self.lock_id})")
break
# Wait a bit before retrying
await asyncio.sleep(0.1)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
await session.execute(
text("SELECT pg_advisory_unlock(:lock_id)"),
{"lock_id": self.lock_id}
)
logger.info(f"Released lock: {self.lock_name} (id={self.lock_id})")
class SimpleDatabaseLock:
"""
Simple table-based distributed lock.
Alternative to advisory locks, uses a dedicated locks table.
"""
def __init__(self, lock_name: str, timeout: float = 30.0, ttl: float = 300.0):
"""
Initialize simple database lock.
Args:
lock_name: Unique identifier for the lock
timeout: Maximum seconds to wait for lock acquisition
ttl: Time-to-live for stale lock cleanup (seconds)
"""
self.lock_name = lock_name
self.timeout = timeout
self.ttl = ttl
async def _ensure_lock_table(self, session: AsyncSession):
"""Ensure locks table exists"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS distributed_locks (
lock_name VARCHAR(255) PRIMARY KEY,
acquired_at TIMESTAMP WITH TIME ZONE NOT NULL,
acquired_by VARCHAR(255),
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
)
"""
await session.execute(text(create_table_sql))
await session.commit()
async def _cleanup_stale_locks(self, session: AsyncSession):
"""Remove expired locks"""
cleanup_sql = """
DELETE FROM distributed_locks
WHERE expires_at < :now
"""
await session.execute(
text(cleanup_sql),
{"now": datetime.now(timezone.utc)}
)
await session.commit()
@asynccontextmanager
async def acquire(self, session: AsyncSession, owner: str = "training-service"):
"""
Acquire simple database lock.
Args:
session: Database session
owner: Identifier for lock owner
Raises:
LockAcquisitionError: If lock cannot be acquired
"""
await self._ensure_lock_table(session)
await self._cleanup_stale_locks(session)
acquired = False
start_time = time.time()
try:
# Try to acquire lock
while time.time() - start_time < self.timeout:
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=self.ttl)
try:
# Try to insert lock record
insert_sql = """
INSERT INTO distributed_locks (lock_name, acquired_at, acquired_by, expires_at)
VALUES (:lock_name, :acquired_at, :acquired_by, :expires_at)
ON CONFLICT (lock_name) DO NOTHING
RETURNING lock_name
"""
result = await session.execute(
text(insert_sql),
{
"lock_name": self.lock_name,
"acquired_at": now,
"acquired_by": owner,
"expires_at": expires_at
}
)
await session.commit()
if result.rowcount > 0:
acquired = True
logger.info(f"Acquired simple lock: {self.lock_name}")
break
except Exception as e:
logger.debug(f"Lock acquisition attempt failed: {e}")
await session.rollback()
# Wait before retrying
await asyncio.sleep(0.5)
if not acquired:
raise LockAcquisitionError(
f"Could not acquire lock '{self.lock_name}' within {self.timeout}s"
)
yield
finally:
if acquired:
# Release lock
delete_sql = """
DELETE FROM distributed_locks
WHERE lock_name = :lock_name
"""
await session.execute(
text(delete_sql),
{"lock_name": self.lock_name}
)
await session.commit()
logger.info(f"Released simple lock: {self.lock_name}")
def get_training_lock(tenant_id: str, product_id: str, use_advisory: bool = True) -> Union[DatabaseLock, SimpleDatabaseLock]:
"""
Get distributed lock for training a specific product.
Args:
tenant_id: Tenant identifier
product_id: Product identifier
use_advisory: Use PostgreSQL advisory locks (True) or table-based (False)
Returns:
Lock instance
"""
lock_name = f"training:{tenant_id}:{product_id}"
if use_advisory:
return DatabaseLock(lock_name, timeout=60.0)
else:
return SimpleDatabaseLock(lock_name, timeout=60.0, ttl=600.0)
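A minimal usage sketch (illustration only; run_training and the bare session argument are hypothetical placeholders):

import logging

logger = logging.getLogger(__name__)

async def run_training(tenant_id: str, product_id: str) -> None:
    ...  # placeholder for the actual training pipeline

async def train_product_exclusively(session, tenant_id: str, product_id: str) -> bool:
    """Run training only if this worker wins the per-product lock (sketch)."""
    lock = get_training_lock(tenant_id, product_id, use_advisory=True)
    try:
        async with lock.acquire(session):
            await run_training(tenant_id, product_id)
        return True
    except LockAcquisitionError:
        # Another instance already holds the lock; skip instead of double-training
        logger.info(f"Training already in progress for {tenant_id}:{product_id}")
        return False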

View File

@@ -0,0 +1,216 @@
"""
File Utility Functions
Utilities for secure file operations including checksum verification
"""
import hashlib
import os
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)
def calculate_file_checksum(file_path: str, algorithm: str = "sha256") -> str:
"""
Calculate checksum of a file.
Args:
file_path: Path to file
algorithm: Hash algorithm (sha256, md5, etc.)
Returns:
Hexadecimal checksum string
Raises:
FileNotFoundError: If file doesn't exist
ValueError: If algorithm not supported
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
hash_func = hashlib.new(algorithm)
except ValueError:
raise ValueError(f"Unsupported hash algorithm: {algorithm}")
# Read file in chunks to handle large files efficiently
with open(file_path, 'rb') as f:
while chunk := f.read(8192):
hash_func.update(chunk)
return hash_func.hexdigest()
def verify_file_checksum(file_path: str, expected_checksum: str, algorithm: str = "sha256") -> bool:
"""
Verify file matches expected checksum.
Args:
file_path: Path to file
expected_checksum: Expected checksum value
algorithm: Hash algorithm used
Returns:
True if checksum matches, False otherwise
"""
try:
actual_checksum = calculate_file_checksum(file_path, algorithm)
matches = actual_checksum == expected_checksum
if matches:
logger.debug(f"Checksum verified for {file_path}")
else:
            logger.warning(
                f"Checksum mismatch for {file_path}: "
                f"expected {expected_checksum}, actual {actual_checksum}"
            )
return matches
except Exception as e:
logger.error(f"Error verifying checksum for {file_path}: {e}")
return False
def get_file_size(file_path: str) -> int:
"""
Get file size in bytes.
Args:
file_path: Path to file
Returns:
File size in bytes
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return os.path.getsize(file_path)
def ensure_directory_exists(directory: str) -> Path:
"""
Ensure directory exists, create if necessary.
Args:
directory: Directory path
Returns:
Path object for directory
"""
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
return path
def safe_file_delete(file_path: str) -> bool:
"""
Safely delete a file, logging any errors.
Args:
file_path: Path to file
Returns:
True if deleted successfully, False otherwise
"""
try:
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Deleted file: {file_path}")
return True
else:
logger.warning(f"File not found for deletion: {file_path}")
return False
except Exception as e:
logger.error(f"Error deleting file {file_path}: {e}")
return False
def get_file_metadata(file_path: str) -> dict:
"""
Get comprehensive file metadata.
Args:
file_path: Path to file
Returns:
Dictionary with file metadata
Raises:
FileNotFoundError: If file doesn't exist
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
stat = os.stat(file_path)
return {
"path": file_path,
"size_bytes": stat.st_size,
"created_at": stat.st_ctime,
"modified_at": stat.st_mtime,
"accessed_at": stat.st_atime,
"is_file": os.path.isfile(file_path),
"is_dir": os.path.isdir(file_path),
"exists": True
}
class ChecksummedFile:
"""
Context manager for working with checksummed files.
Automatically calculates and stores checksum when file is written.
"""
def __init__(self, file_path: str, checksum_path: Optional[str] = None, algorithm: str = "sha256"):
"""
Initialize checksummed file handler.
Args:
file_path: Path to the file
checksum_path: Path to store checksum (default: file_path + '.checksum')
algorithm: Hash algorithm to use
"""
self.file_path = file_path
self.checksum_path = checksum_path or f"{file_path}.checksum"
self.algorithm = algorithm
self.checksum: Optional[str] = None
def calculate_and_save_checksum(self) -> str:
"""Calculate checksum and save to file"""
self.checksum = calculate_file_checksum(self.file_path, self.algorithm)
with open(self.checksum_path, 'w') as f:
f.write(f"{self.checksum} {os.path.basename(self.file_path)}\n")
logger.info(f"Saved checksum for {self.file_path}: {self.checksum}")
return self.checksum
def load_and_verify_checksum(self) -> bool:
"""Load expected checksum and verify file"""
try:
with open(self.checksum_path, 'r') as f:
expected_checksum = f.read().strip().split()[0]
return verify_file_checksum(self.file_path, expected_checksum, self.algorithm)
except FileNotFoundError:
logger.warning(f"Checksum file not found: {self.checksum_path}")
return False
except Exception as e:
logger.error(f"Error loading checksum: {e}")
return False
def get_stored_checksum(self) -> Optional[str]:
"""Get checksum from stored file"""
try:
with open(self.checksum_path, 'r') as f:
return f.read().strip().split()[0]
except FileNotFoundError:
return None
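A minimal write-then-verify sketch (illustration only; the artifact path is hypothetical):

model_path = "/var/lib/training/models/demand_model.pkl"  # hypothetical artifact path

# Persist the checksum next to the artifact right after it is written...
artifact = ChecksummedFile(model_path)
artifact.calculate_and_save_checksum()

# ...and refuse to load an artifact whose content no longer matches
if not artifact.load_and_verify_checksum():
    raise ValueError(f"Model artifact failed checksum verification: {model_path}")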

View File

@@ -0,0 +1,316 @@
"""
Retry Mechanism with Exponential Backoff
Handles transient failures with intelligent retry strategies
"""
import asyncio
import time
import random
from typing import Callable, Any, Optional, Type, Tuple
from functools import wraps
import logging
logger = logging.getLogger(__name__)
class RetryError(Exception):
"""Raised when all retry attempts are exhausted"""
def __init__(self, message: str, attempts: int, last_exception: Exception):
super().__init__(message)
self.attempts = attempts
self.last_exception = last_exception
class RetryStrategy:
"""Base retry strategy"""
def __init__(
self,
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Initialize retry strategy.
Args:
max_attempts: Maximum number of retry attempts
initial_delay: Initial delay in seconds
max_delay: Maximum delay between retries
exponential_base: Base for exponential backoff
jitter: Add random jitter to prevent thundering herd
retriable_exceptions: Tuple of exception types to retry
"""
self.max_attempts = max_attempts
self.initial_delay = initial_delay
self.max_delay = max_delay
self.exponential_base = exponential_base
self.jitter = jitter
self.retriable_exceptions = retriable_exceptions
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt using exponential backoff"""
delay = min(
self.initial_delay * (self.exponential_base ** attempt),
self.max_delay
)
if self.jitter:
            # Add random jitter: keep between 50% and 100% of the computed delay
delay = delay * (0.5 + random.random() * 0.5)
return delay
def is_retriable(self, exception: Exception) -> bool:
"""Check if exception should trigger retry"""
return isinstance(exception, self.retriable_exceptions)
async def retry_async(
func: Callable,
*args,
strategy: Optional[RetryStrategy] = None,
**kwargs
) -> Any:
"""
Retry async function with exponential backoff.
Args:
func: Async function to retry
*args: Positional arguments for func
strategy: Retry strategy (uses default if None)
**kwargs: Keyword arguments for func
Returns:
Result from func
Raises:
RetryError: When all attempts exhausted
"""
if strategy is None:
strategy = RetryStrategy()
last_exception = None
for attempt in range(strategy.max_attempts):
try:
result = await func(*args, **kwargs)
if attempt > 0:
                logger.info(
                    f"Retry of {func.__name__} succeeded on attempt {attempt + 1}"
                )
return result
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
                logger.error(
                    f"Non-retriable exception in {func.__name__}: {e}"
                )
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
                logger.warning(
                    f"Attempt {attempt + 1}/{strategy.max_attempts} of {func.__name__} "
                    f"failed ({e}); retrying in {delay:.2f}s"
                )
await asyncio.sleep(delay)
else:
                logger.error(
                    f"All {strategy.max_attempts} retry attempts of {func.__name__} exhausted: {e}"
                )
raise RetryError(
f"Failed after {strategy.max_attempts} attempts: {str(last_exception)}",
attempts=strategy.max_attempts,
last_exception=last_exception
)
def with_retry(
max_attempts: int = 3,
initial_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
"""
Decorator to add retry logic to async functions.
Example:
@with_retry(max_attempts=5, initial_delay=2.0)
async def fetch_data():
# Your code here
pass
"""
strategy = RetryStrategy(
max_attempts=max_attempts,
initial_delay=initial_delay,
max_delay=max_delay,
exponential_base=exponential_base,
jitter=jitter,
retriable_exceptions=retriable_exceptions
)
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, **kwargs):
return await retry_async(func, *args, strategy=strategy, **kwargs)
return wrapper
return decorator
class AdaptiveRetryStrategy(RetryStrategy):
"""
Adaptive retry strategy that adjusts based on success/failure patterns.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.success_count = 0
self.failure_count = 0
self.consecutive_failures = 0
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay with adaptation based on recent history"""
base_delay = super().calculate_delay(attempt)
# Increase delay if seeing consecutive failures
if self.consecutive_failures > 5:
multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
base_delay *= multiplier
return min(base_delay, self.max_delay)
def record_success(self):
"""Record successful attempt"""
self.success_count += 1
self.consecutive_failures = 0
def record_failure(self):
"""Record failed attempt"""
self.failure_count += 1
self.consecutive_failures += 1
class TimeoutRetryStrategy(RetryStrategy):
"""
Retry strategy with overall timeout across all attempts.
"""
def __init__(self, *args, timeout: float = 300.0, **kwargs):
"""
Args:
timeout: Total timeout in seconds for all attempts
"""
super().__init__(*args, **kwargs)
self.timeout = timeout
self.start_time: Optional[float] = None
def should_retry(self, attempt: int) -> bool:
"""Check if should attempt another retry"""
if self.start_time is None:
self.start_time = time.time()
return True
elapsed = time.time() - self.start_time
return elapsed < self.timeout and attempt < self.max_attempts
async def retry_with_timeout(
func: Callable,
*args,
max_attempts: int = 3,
timeout: float = 300.0,
**kwargs
) -> Any:
"""
Retry with overall timeout.
Args:
func: Function to retry
max_attempts: Maximum attempts
timeout: Overall timeout in seconds
Returns:
Result from func
"""
strategy = TimeoutRetryStrategy(
max_attempts=max_attempts,
timeout=timeout
)
start_time = time.time()
strategy.start_time = start_time
last_exception = None
for attempt in range(strategy.max_attempts):
if time.time() - start_time >= timeout:
raise RetryError(
f"Timeout of {timeout}s exceeded",
attempts=attempt + 1,
last_exception=last_exception
)
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if not strategy.is_retriable(e):
raise
if attempt < strategy.max_attempts - 1:
delay = strategy.calculate_delay(attempt)
await asyncio.sleep(delay)
raise RetryError(
f"Failed after {strategy.max_attempts} attempts",
attempts=strategy.max_attempts,
last_exception=last_exception
)
# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
max_attempts=3,
initial_delay=1.0,
max_delay=10.0,
exponential_base=2.0,
jitter=True
)
DATABASE_RETRY_STRATEGY = RetryStrategy(
max_attempts=5,
initial_delay=0.5,
max_delay=5.0,
exponential_base=1.5,
jitter=True
)
EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
max_attempts=4,
initial_delay=2.0,
max_delay=30.0,
exponential_base=2.5,
jitter=True
)
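A minimal usage sketch for the decorator and the pre-configured strategies (illustration only; the endpoints, function names, and use of httpx are assumptions):

import httpx

@with_retry(max_attempts=4, initial_delay=2.0, retriable_exceptions=(httpx.HTTPError,))
async def fetch_sales_history(product_id: str) -> list:
    # Hypothetical upstream endpoint; transient HTTP failures are retried with backoff
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(f"https://sales-api.internal/history/{product_id}")
        response.raise_for_status()
        return response.json()

async def fetch_product_metadata(client: httpx.AsyncClient, product_id: str) -> dict:
    response = await client.get(f"https://catalog-api.internal/products/{product_id}")
    response.raise_for_status()
    return response.json()

async def load_product(product_id: str) -> dict:
    # The same machinery can be driven explicitly with a shared strategy
    async with httpx.AsyncClient(timeout=10.0) as client:
        return await retry_async(
            fetch_product_metadata, client, product_id,
            strategy=EXTERNAL_SERVICE_RETRY_STRATEGY
        )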

View File

@@ -0,0 +1,184 @@
"""
Timezone Utility Functions
Centralized timezone handling to ensure consistency across the training service
"""
from datetime import datetime, timezone
from typing import Optional, Union
import pandas as pd
import logging
logger = logging.getLogger(__name__)
def ensure_timezone_aware(dt: Optional[datetime], default_tz=timezone.utc) -> Optional[datetime]:
"""
Ensure a datetime is timezone-aware.
Args:
dt: Datetime to check
default_tz: Timezone to apply if datetime is naive (default: UTC)
Returns:
Timezone-aware datetime
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=default_tz)
return dt
def ensure_timezone_naive(dt: Optional[datetime]) -> Optional[datetime]:
"""
Remove timezone information from a datetime.
Args:
dt: Datetime to process
Returns:
Timezone-naive datetime
"""
if dt is None:
return None
if dt.tzinfo is not None:
return dt.replace(tzinfo=None)
return dt
def normalize_datetime_to_utc(dt: Optional[Union[datetime, pd.Timestamp]]) -> Optional[datetime]:
"""
Normalize any datetime to UTC timezone-aware datetime.
Args:
dt: Datetime or pandas Timestamp to normalize
Returns:
UTC timezone-aware datetime
"""
if dt is None:
return None
# Handle pandas Timestamp
if isinstance(dt, pd.Timestamp):
dt = dt.to_pydatetime()
# If naive, assume UTC
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
# If aware but not UTC, convert to UTC
return dt.astimezone(timezone.utc)
def normalize_dataframe_datetime_column(
df: pd.DataFrame,
column: str,
target_format: str = 'naive'
) -> pd.DataFrame:
"""
Normalize a datetime column in a dataframe to consistent format.
Args:
df: DataFrame to process
column: Name of datetime column
target_format: 'naive' or 'aware' (UTC)
Returns:
DataFrame with normalized datetime column
"""
if column not in df.columns:
logger.warning(f"Column {column} not found in dataframe")
return df
# Convert to datetime if not already
df[column] = pd.to_datetime(df[column])
if target_format == 'naive':
# Remove timezone if present
if df[column].dt.tz is not None:
df[column] = df[column].dt.tz_localize(None)
elif target_format == 'aware':
# Add UTC timezone if not present
if df[column].dt.tz is None:
df[column] = df[column].dt.tz_localize(timezone.utc)
else:
# Convert to UTC if different timezone
df[column] = df[column].dt.tz_convert(timezone.utc)
else:
raise ValueError(f"Invalid target_format: {target_format}. Must be 'naive' or 'aware'")
return df
def prepare_prophet_datetime(df: pd.DataFrame, datetime_col: str = 'ds') -> pd.DataFrame:
"""
Prepare datetime column for Prophet (requires timezone-naive datetimes).
Args:
df: DataFrame with datetime column
datetime_col: Name of datetime column (default: 'ds')
Returns:
DataFrame with Prophet-compatible datetime column
"""
df = df.copy()
df = normalize_dataframe_datetime_column(df, datetime_col, target_format='naive')
return df
def safe_datetime_comparison(dt1: datetime, dt2: datetime) -> int:
"""
Safely compare two datetimes, handling timezone mismatches.
Args:
dt1: First datetime
dt2: Second datetime
Returns:
-1 if dt1 < dt2, 0 if equal, 1 if dt1 > dt2
"""
# Normalize both to UTC for comparison
dt1_utc = normalize_datetime_to_utc(dt1)
dt2_utc = normalize_datetime_to_utc(dt2)
if dt1_utc < dt2_utc:
return -1
elif dt1_utc > dt2_utc:
return 1
else:
return 0
def get_current_utc() -> datetime:
"""
Get current datetime in UTC with timezone awareness.
Returns:
Current UTC datetime
"""
return datetime.now(timezone.utc)
def convert_timestamp_to_datetime(timestamp: Union[int, float, str]) -> datetime:
"""
Convert various timestamp formats to datetime.
Args:
timestamp: Unix timestamp (seconds or milliseconds) or ISO string
Returns:
UTC timezone-aware datetime
"""
if isinstance(timestamp, str):
dt = pd.to_datetime(timestamp)
return normalize_datetime_to_utc(dt)
# Check if milliseconds (typical JavaScript timestamp)
if timestamp > 1e10:
timestamp = timestamp / 1000
dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
return dt
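A minimal usage sketch tying the helpers together (illustration only; the sample data is made up):

import pandas as pd

# Made-up tz-aware sales history, e.g. as returned by an external service
df = pd.DataFrame({
    "ds": pd.to_datetime(["2025-10-01", "2025-10-02"]).tz_localize("Europe/Madrid"),
    "y": [120, 95],
})

# Prophet requires timezone-naive timestamps, so strip tz info consistently
prophet_df = prepare_prophet_datetime(df, datetime_col="ds")
assert prophet_df["ds"].dt.tz is None

# Compare a unix timestamp against "now" without tripping over naive/aware mixing
is_older = safe_datetime_comparison(
    convert_timestamp_to_datetime(1733000000),   # seconds since epoch
    get_current_utc(),
) == -1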