199 lines
6.4 KiB
Python
199 lines
6.4 KiB
Python
|
|
"""
|
||
|
|
Circuit Breaker Pattern Implementation
|
||
|
|
Protects against cascading failures from external service calls
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
from enum import Enum
|
||
|
|
from typing import Callable, Any, Optional
|
||
|
|
import logging
|
||
|
|
from functools import wraps
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class CircuitState(Enum):
    """Lifecycle states a circuit breaker can be in"""

    # Normal operation: calls are passed through
    CLOSED = "closed"

    # Circuit is open: calls are rejected
    OPEN = "open"

    # Trial phase: checking whether the service has recovered
    HALF_OPEN = "half_open"
|
||
|
|
|
||
|
|
|
||
|
|
class CircuitBreakerError(Exception):
    """Signals that a call was rejected because the circuit breaker is open"""
|
||
|
|
|
||
|
|
|
||
|
|
class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures.

    Wraps calls to an async callable, counts failures of
    ``expected_exception`` and rejects further calls with
    ``CircuitBreakerError`` once ``failure_threshold`` is reached. Any
    successful call resets the failure counter.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Too many failures, rejecting all requests until
      ``recovery_timeout`` seconds have passed since the last failure
    - HALF_OPEN: Testing if service recovered; calls pass through again and
      the next recorded success closes the circuit
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
        name: str = "circuit_breaker"
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit
            recovery_timeout: Seconds to wait before attempting recovery
            expected_exception: Exception type to catch (others will pass through)
            name: Name for logging purposes
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name

        # Mutable runtime state; a new breaker starts closed with no failures
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.state = CircuitState.CLOSED

    def _record_success(self):
        """Record successful call: clear the failure history and close a half-open circuit"""
        self.failure_count = 0
        self.last_failure_time = None
        if self.state == CircuitState.HALF_OPEN:
            logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
            self.state = CircuitState.CLOSED

    def _record_failure(self):
        """Record failed call and open the circuit once the threshold is reached"""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.failure_count >= self.failure_threshold:
            # Only warn on the transition, not on every failure while already open
            if self.state != CircuitState.OPEN:
                logger.warning(
                    f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
                )
            self.state = CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Check if the circuit is open and the recovery timeout has elapsed since the last failure"""
        return (
            self.state == CircuitState.OPEN
            and self.last_failure_time is not None
            and time.time() - self.last_failure_time >= self.recovery_timeout
        )

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            CircuitBreakerError: If circuit is open and the recovery timeout
                has not elapsed yet (func is not called in that case)
            Exception: Whatever func raises. Instances of expected_exception
                are recorded as failures and re-raised; other exceptions
                propagate without being recorded
        """
        # Check if circuit is open
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerError(
                    f"Circuit breaker '{self.name}' is open. "
                    f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
                )

        try:
            # Execute the function
            result = await func(*args, **kwargs)
            self._record_success()
            return result

        except self.expected_exception as e:
            self._record_failure()
            # The stdlib logger only accepts exc_info/stack_info/stacklevel/extra
            # as keyword arguments; passing arbitrary ones raises TypeError here
            # and would hide the original exception from the caller. The details
            # are therefore formatted into the message itself.
            logger.error(
                "Circuit breaker '%s' caught failure: error=%s failure_count=%s state=%s",
                self.name,
                e,
                self.failure_count,
                self.state.value,
            )
            raise

    def __call__(self, func: Callable) -> Callable:
        """Decorator interface for circuit breaker: wrap an async function so every call goes through ``call``"""
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await self.call(func, *args, **kwargs)
        return wrapper

    def get_state(self) -> dict:
        """Get current circuit breaker state and configuration as a plain dict for monitoring"""
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self.failure_count,
            "failure_threshold": self.failure_threshold,
            "last_failure_time": self.last_failure_time,
            "recovery_timeout": self.recovery_timeout
        }
|
||
|
|
|
||
|
|
|
||
|
|
class CircuitBreakerRegistry:
    """Name-indexed collection of circuit breakers"""

    def __init__(self):
        # Maps breaker name -> breaker instance
        self._breakers: dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ) -> CircuitBreaker:
        """
        Return the breaker registered under ``name``, building it on first use.

        The configuration arguments are only applied when a new breaker is
        created; an already registered breaker is returned as it is.
        """
        existing = self._breakers.get(name)
        if existing is not None:
            return existing

        created = CircuitBreaker(
            failure_threshold=failure_threshold,
            recovery_timeout=recovery_timeout,
            expected_exception=expected_exception,
            name=name
        )
        self._breakers[name] = created
        return created

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Look up a breaker by name; returns None when nothing is registered under it"""
        try:
            return self._breakers[name]
        except KeyError:
            return None

    def get_all_states(self) -> dict:
        """Collect the ``get_state()`` snapshot of every registered breaker, keyed by name"""
        snapshots = {}
        for breaker_name, breaker in self._breakers.items():
            snapshots[breaker_name] = breaker.get_state()
        return snapshots

    def reset(self, name: str):
        """Force the named breaker back to CLOSED with a clean failure history; unknown names are ignored"""
        target = self._breakers.get(name)
        if target is None:
            return

        target.state = CircuitState.CLOSED
        target.failure_count = 0
        target.last_failure_time = None
        logger.info(f"Circuit breaker '{name}' manually reset")
|
||
|
|
|
||
|
|
|
||
|
|
# Global registry instance
|
||
|
|
# Module-level CircuitBreakerRegistry instance, created once when this
# module is imported.
circuit_breaker_registry = CircuitBreakerRegistry()
|