169 lines
5.5 KiB
Python
Executable File
169 lines
5.5 KiB
Python
Executable File
"""
Circuit Breaker Pattern Implementation

Prevents cascading failures by stopping requests to failing services
and allowing them time to recover.
"""
|
|
|
|
import asyncio
import inspect
import logging
import time
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Optional
|
|
|
|
logger = logging.getLogger(__name__)


class CircuitState(str, Enum):
    """Circuit breaker states (members compare equal to their string values)."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Circuit is open, requests fail immediately
    HALF_OPEN = "half_open"  # Testing if service has recovered


class CircuitBreakerOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` when the circuit is open."""


class CircuitBreaker:
    """
    Circuit Breaker implementation for protecting service calls.

    States:
        - CLOSED: Normal operation, requests pass through. Counted failures
          accumulate; any success resets the tally. Reaching
          ``failure_threshold`` opens the circuit.
        - OPEN: Requests fail immediately with ``CircuitBreakerOpenError``
          until ``timeout_duration`` seconds have elapsed.
        - HALF_OPEN: Testing recovery. Trial requests pass through;
          ``success_threshold`` successes close the circuit and any counted
          failure re-opens it. The number of concurrent trial requests is
          not capped.

    Args:
        failure_threshold: Number of failures before opening circuit
        timeout_duration: Seconds to wait before attempting recovery
        success_threshold: Successful calls needed in HALF_OPEN to close circuit
        expected_exceptions: Tuple of exceptions that count as failures
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        timeout_duration: float = 60,
        success_threshold: int = 2,
        expected_exceptions: tuple = (Exception,),
    ):
        self.failure_threshold = failure_threshold
        self.timeout_duration = timeout_duration
        self.success_threshold = success_threshold
        self.expected_exceptions = expected_exceptions

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        # Wall-clock timestamps, used only for reporting (error message, stats).
        self._last_failure_time: Optional[datetime] = None
        self._next_attempt_time: Optional[datetime] = None
        # Monotonic deadline that actually gates OPEN -> HALF_OPEN; unlike the
        # naive local datetime it cannot jump on DST changes or clock steps.
        self._next_attempt_monotonic: Optional[float] = None

    @property
    def state(self) -> CircuitState:
        """Return the current state, moving OPEN -> HALF_OPEN once the timeout has elapsed."""
        if self._state == CircuitState.OPEN and self._should_attempt_reset():
            self._state = CircuitState.HALF_OPEN
            self._success_count = 0
            logger.info("Circuit breaker entering HALF_OPEN state")
        return self._state

    def _should_attempt_reset(self) -> bool:
        """Check if enough time has passed since opening to attempt a reset."""
        if self._next_attempt_monotonic is None:
            return False
        return time.monotonic() >= self._next_attempt_monotonic

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute ``func`` with circuit breaker protection.

        ``func`` may be a plain function or any callable that returns an
        awaitable (``async def`` functions, lambdas or callable objects that
        return a coroutine, ...). Awaitable results are awaited before the
        outcome is recorded.

        Args:
            func: Function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result of func execution (awaited if it was awaitable)

        Raises:
            CircuitBreakerOpenError: If circuit is open
            Exception: Any exception raised by func is re-raised unchanged;
                only instances of ``expected_exceptions`` count as failures.
        """
        if self.state == CircuitState.OPEN:
            raise CircuitBreakerOpenError(
                f"Circuit breaker is OPEN. Next attempt at {self._next_attempt_time}"
            )

        try:
            result = func(*args, **kwargs)
            # Checking the result (not the callable) also catches sync
            # wrappers that return a coroutine, which iscoroutinefunction
            # misses; otherwise success would be recorded before the work ran.
            if inspect.isawaitable(result):
                result = await result
        except self.expected_exceptions:
            self._on_failure()
            raise
        else:
            self._on_success()
            return result

    def _on_success(self):
        """Record a successful call."""
        if self._state == CircuitState.HALF_OPEN:
            self._success_count += 1
            if self._success_count >= self.success_threshold:
                self._close_circuit()
        elif self._state == CircuitState.CLOSED:
            # In CLOSED state, reset failure count on success
            self._failure_count = 0
        # While OPEN (a call that started before the circuit opened), a late
        # success is ignored so it cannot erase the failure record.

    def _on_failure(self):
        """Record a failed call and open the circuit when warranted."""
        self._failure_count += 1
        self._last_failure_time = datetime.now()

        if self._state == CircuitState.HALF_OPEN:
            # Failure in HALF_OPEN returns to OPEN
            self._open_circuit()
        elif (
            self._state == CircuitState.CLOSED
            and self._failure_count >= self.failure_threshold
        ):
            # Too many failures, open the circuit
            self._open_circuit()
        # Already OPEN: a late failure from an earlier call is recorded but
        # must not push the retry deadline further out.

    def _open_circuit(self):
        """Open the circuit and schedule the next recovery attempt."""
        self._state = CircuitState.OPEN
        self._next_attempt_monotonic = time.monotonic() + self.timeout_duration
        self._next_attempt_time = datetime.now() + timedelta(seconds=self.timeout_duration)
        logger.warning(
            "Circuit breaker opened after %s failures. Next attempt at %s",
            self._failure_count,
            self._next_attempt_time,
        )

    def _enter_closed_state(self):
        """Switch to CLOSED with cleared counters and no pending retry deadline."""
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._success_count = 0
        self._next_attempt_time = None
        self._next_attempt_monotonic = None

    def _close_circuit(self):
        """Close the circuit after a successful HALF_OPEN recovery."""
        self._enter_closed_state()
        logger.info("Circuit breaker closed after successful recovery")

    def reset(self):
        """Manually reset circuit breaker to CLOSED state."""
        self._enter_closed_state()
        logger.info("Circuit breaker manually reset")

    def get_stats(self) -> dict:
        """Get circuit breaker statistics (reading ``state`` may move OPEN -> HALF_OPEN)."""
        return {
            "state": self.state.value,
            "failure_count": self._failure_count,
            "success_count": self._success_count,
            "last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
            "next_attempt_time": self._next_attempt_time.isoformat() if self._next_attempt_time else None,
        }
|