"""
|
|
Retry Mechanism with Exponential Backoff
|
|
Handles transient failures with intelligent retry strategies
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
import random
|
|
from typing import Callable, Any, Optional, Type, Tuple
|
|
from functools import wraps
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)


class RetryError(Exception):
    """Raised when all retry attempts are exhausted"""

    def __init__(self, message: str, attempts: int, last_exception: Exception):
        super().__init__(message)
        self.attempts = attempts
        self.last_exception = last_exception


class RetryStrategy:
    """Base retry strategy"""

    def __init__(
        self,
        max_attempts: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
        retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
    ):
        """
        Initialize retry strategy.

        Args:
            max_attempts: Maximum number of retry attempts
            initial_delay: Initial delay in seconds
            max_delay: Maximum delay between retries
            exponential_base: Base for exponential backoff
            jitter: Add random jitter to prevent thundering herd
            retriable_exceptions: Tuple of exception types to retry
        """
        self.max_attempts = max_attempts
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter
        self.retriable_exceptions = retriable_exceptions

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay for given attempt using exponential backoff"""
        delay = min(
            self.initial_delay * (self.exponential_base ** attempt),
            self.max_delay
        )

        if self.jitter:
            # Add random jitter (50-100% of delay) to avoid thundering herd
            delay = delay * (0.5 + random.random() * 0.5)

        return delay

    def is_retriable(self, exception: Exception) -> bool:
        """Check if exception should trigger retry"""
        return isinstance(exception, self.retriable_exceptions)
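
# Worked example of RetryStrategy's backoff schedule (illustrative note, not
# from the original module): with the defaults above (initial_delay=1.0,
# exponential_base=2.0, max_delay=60.0), calculate_delay(attempt) yields
# 1s, 2s, 4s, 8s, ... for attempts 0, 1, 2, 3, capped at 60s; with jitter=True
# each value is then scaled by a random factor in [0.5, 1.0), so attempt 2
# waits somewhere between 2.0s and 4.0s.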


async def retry_async(
    func: Callable,
    *args,
    strategy: Optional[RetryStrategy] = None,
    **kwargs
) -> Any:
    """
    Retry async function with exponential backoff.

    Args:
        func: Async function to retry
        *args: Positional arguments for func
        strategy: Retry strategy (uses default if None)
        **kwargs: Keyword arguments for func

    Returns:
        Result from func

    Raises:
        RetryError: When all attempts are exhausted
    """
    if strategy is None:
        strategy = RetryStrategy()

    last_exception = None

    for attempt in range(strategy.max_attempts):
        try:
            result = await func(*args, **kwargs)

            if attempt > 0:
                logger.info(
                    f"{func.__name__}: retry succeeded on attempt {attempt + 1}"
                )

            return result

        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                logger.error(
                    f"Non-retriable exception in {func.__name__}: {e}"
                )
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                logger.warning(
                    f"{func.__name__}: attempt {attempt + 1}/{strategy.max_attempts} "
                    f"failed ({e}), retrying in {delay:.2f}s"
                )
                await asyncio.sleep(delay)
            else:
                logger.error(
                    f"All {strategy.max_attempts} retry attempts of "
                    f"{func.__name__} exhausted: {e}"
                )

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts: {last_exception}",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )
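
# Usage sketch for retry_async (illustrative; `fetch_page` is a hypothetical
# coroutine, not defined in this module):
#
#     strategy = RetryStrategy(max_attempts=5, retriable_exceptions=(OSError,))
#     html = await retry_async(fetch_page, "https://example.com", strategy=strategy)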


def with_retry(
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retriable_exceptions: Tuple[Type[Exception], ...] = (Exception,)
):
    """
    Decorator to add retry logic to async functions.

    Example:
        @with_retry(max_attempts=5, initial_delay=2.0)
        async def fetch_data():
            # Your code here
            pass
    """
    strategy = RetryStrategy(
        max_attempts=max_attempts,
        initial_delay=initial_delay,
        max_delay=max_delay,
        exponential_base=exponential_base,
        jitter=jitter,
        retriable_exceptions=retriable_exceptions
    )

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await retry_async(func, *args, strategy=strategy, **kwargs)
        return wrapper

    return decorator


class AdaptiveRetryStrategy(RetryStrategy):
    """
    Adaptive retry strategy that adjusts based on success/failure patterns.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.success_count = 0
        self.failure_count = 0
        self.consecutive_failures = 0

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay with adaptation based on recent history"""
        base_delay = super().calculate_delay(attempt)

        # Increase delay if seeing consecutive failures
        if self.consecutive_failures > 5:
            multiplier = min(2.0, 1.0 + (self.consecutive_failures - 5) * 0.2)
            base_delay *= multiplier

        return min(base_delay, self.max_delay)

    def record_success(self):
        """Record successful attempt"""
        self.success_count += 1
        self.consecutive_failures = 0

    def record_failure(self):
        """Record failed attempt"""
        self.failure_count += 1
        self.consecutive_failures += 1
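
# Note / usage sketch (illustrative, not from the original module): retry_async
# does not call record_success()/record_failure() itself, so the caller is
# expected to feed outcomes back into the strategy, e.g.:
#
#     adaptive = AdaptiveRetryStrategy(max_attempts=4)
#     try:
#         result = await retry_async(some_coroutine, strategy=adaptive)
#         adaptive.record_success()
#     except RetryError:
#         adaptive.record_failure()
#
# where `some_coroutine` is a hypothetical async callable.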


class TimeoutRetryStrategy(RetryStrategy):
    """
    Retry strategy with overall timeout across all attempts.
    """

    def __init__(self, *args, timeout: float = 300.0, **kwargs):
        """
        Args:
            timeout: Total timeout in seconds for all attempts
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.start_time: Optional[float] = None

    def should_retry(self, attempt: int) -> bool:
        """Check whether another retry attempt should be made"""
        if self.start_time is None:
            self.start_time = time.time()
            return True

        elapsed = time.time() - self.start_time
        return elapsed < self.timeout and attempt < self.max_attempts


async def retry_with_timeout(
    func: Callable,
    *args,
    max_attempts: int = 3,
    timeout: float = 300.0,
    **kwargs
) -> Any:
    """
    Retry with overall timeout.

    Args:
        func: Async function to retry
        *args: Positional arguments for func
        max_attempts: Maximum attempts
        timeout: Overall timeout in seconds
        **kwargs: Keyword arguments for func

    Returns:
        Result from func

    Raises:
        RetryError: When all attempts are exhausted or the timeout is exceeded
    """
    strategy = TimeoutRetryStrategy(
        max_attempts=max_attempts,
        timeout=timeout
    )

    start_time = time.time()
    strategy.start_time = start_time

    last_exception = None

    for attempt in range(strategy.max_attempts):
        if time.time() - start_time >= timeout:
            raise RetryError(
                f"Timeout of {timeout}s exceeded",
                attempts=attempt + 1,
                last_exception=last_exception
            )

        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e

            if not strategy.is_retriable(e):
                raise

            if attempt < strategy.max_attempts - 1:
                delay = strategy.calculate_delay(attempt)
                await asyncio.sleep(delay)

    raise RetryError(
        f"Failed after {strategy.max_attempts} attempts: {last_exception}",
        attempts=strategy.max_attempts,
        last_exception=last_exception
    )
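
# Usage sketch for retry_with_timeout (illustrative; `sync_inventory` is a
# hypothetical coroutine, not defined in this module):
#
#     data = await retry_with_timeout(sync_inventory, warehouse_id,
#                                     max_attempts=10, timeout=60.0)
#
# The call gives up after 10 attempts or once 60 seconds have elapsed,
# whichever comes first, raising RetryError in either case.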


# Pre-configured strategies for common use cases
HTTP_RETRY_STRATEGY = RetryStrategy(
    max_attempts=3,
    initial_delay=1.0,
    max_delay=10.0,
    exponential_base=2.0,
    jitter=True
)

DATABASE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=5,
    initial_delay=0.5,
    max_delay=5.0,
    exponential_base=1.5,
    jitter=True
)

EXTERNAL_SERVICE_RETRY_STRATEGY = RetryStrategy(
    max_attempts=4,
    initial_delay=2.0,
    max_delay=30.0,
    exponential_base=2.5,
    jitter=True
)
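

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original
    # module's public API): a deliberately flaky coroutine retried with the
    # pre-configured HTTP strategy.
    logging.basicConfig(level=logging.INFO)

    async def _flaky_call() -> str:
        # Fail roughly 70% of the time to exercise the retry path
        if random.random() < 0.7:
            raise ConnectionError("simulated transient failure")
        return "ok"

    async def _demo() -> None:
        try:
            result = await retry_async(_flaky_call, strategy=HTTP_RETRY_STRATEGY)
            print(f"Succeeded: {result}")
        except RetryError as exc:
            print(f"Gave up after {exc.attempts} attempts: {exc.last_exception}")

    asyncio.run(_demo())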