"""
Circuit Breaker Pattern Implementation.

Protects against cascading failures from external service calls by
short-circuiting calls to a dependency after repeated failures.
"""

import asyncio
import logging
import time
from enum import Enum
from functools import wraps
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Circuit breaker states."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Circuit is open, rejecting requests
    HALF_OPEN = "half_open"  # Testing if service recovered


class CircuitBreakerError(Exception):
    """Raised when a call is rejected because the circuit breaker is open."""


class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures.

    States:
    - CLOSED: Normal operation, requests pass through.
    - OPEN: Too many consecutive failures; calls are rejected with
      CircuitBreakerError until ``recovery_timeout`` seconds have elapsed
      since the last failure.
    - HALF_OPEN: Recovery is being probed. A success closes the circuit;
      a counted failure re-opens it immediately. Note: calls arriving while
      the breaker is half-open are not throttled — every one of them is let
      through until one of them resolves the state.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
        name: str = "circuit_breaker",
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of consecutive counted failures that
                opens the circuit.
            recovery_timeout: Seconds after the last failure before an open
                circuit lets a trial call through (half-open).
            expected_exception: Exception type (or tuple of types) that is
                counted as a failure. All exceptions are re-raised to the
                caller; other types simply do not affect the breaker.
            name: Name used in log messages and monitoring output.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name

        self.failure_count = 0
        # Wall-clock time (time.time()) of the most recent failure; kept for
        # monitoring via get_state().
        self.last_failure_time: Optional[float] = None
        # Monotonic time of the most recent failure; used for the recovery
        # timeout so system clock adjustments cannot shorten or extend it.
        self._last_failure_monotonic: Optional[float] = None
        self.state = CircuitState.CLOSED

    def _record_success(self) -> None:
        """Clear failure history and close the circuit if it was half-open."""
        self.failure_count = 0
        self.last_failure_time = None
        self._last_failure_monotonic = None

        if self.state == CircuitState.HALF_OPEN:
            logger.info("Circuit breaker '%s' recovered, closing circuit", self.name)
            self.state = CircuitState.CLOSED

    def _record_failure(self) -> None:
        """Count a failure and open the circuit when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        self._last_failure_monotonic = time.monotonic()

        # A failed trial call while half-open re-opens the circuit at once,
        # independent of the accumulated count.
        if (
            self.state == CircuitState.HALF_OPEN
            or self.failure_count >= self.failure_threshold
        ):
            if self.state != CircuitState.OPEN:
                logger.warning(
                    "Circuit breaker '%s' opened after %d failures",
                    self.name,
                    self.failure_count,
                )
            self.state = CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Return True if the circuit is open and the recovery timeout has elapsed."""
        return (
            self.state == CircuitState.OPEN
            and self._last_failure_monotonic is not None
            and time.monotonic() - self._last_failure_monotonic >= self.recovery_timeout
        )

    def reset(self) -> None:
        """Force the breaker back to CLOSED and clear its failure history."""
        self.failure_count = 0
        self.last_failure_time = None
        self._last_failure_monotonic = None
        self.state = CircuitState.CLOSED

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute function with circuit breaker protection.

        Args:
            func: Async function to execute.
            *args: Positional arguments for func.
            **kwargs: Keyword arguments for func.

        Returns:
            Result awaited from ``func(*args, **kwargs)``.

        Raises:
            CircuitBreakerError: If the circuit is open and the recovery
                timeout has not yet elapsed.
            Exception: Any exception raised by func is re-raised unchanged;
                only ``expected_exception`` types are counted as failures.
        """
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                logger.info(
                    "Circuit breaker '%s' attempting recovery (half-open)", self.name
                )
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerError(
                    f"Circuit breaker '{self.name}' is open. "
                    f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
                )

        try:
            result = await func(*args, **kwargs)
        except self.expected_exception as exc:
            self._record_failure()
            # Stdlib logging rejects arbitrary keyword arguments, so the
            # structured fields go into the message and into `extra`.
            logger.error(
                "Circuit breaker '%s' caught failure: %s (failure_count=%d, state=%s)",
                self.name,
                exc,
                self.failure_count,
                self.state.value,
                extra={
                    "circuit_breaker": self.name,
                    "error": str(exc),
                    "failure_count": self.failure_count,
                    "circuit_state": self.state.value,
                },
            )
            raise
        else:
            self._record_success()
            return result

    def __call__(self, func: Callable) -> Callable:
        """Decorator interface: wrap an async function so calls go through call()."""

        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await self.call(func, *args, **kwargs)

        return wrapper

    def get_state(self) -> dict:
        """Return a snapshot of this breaker's configuration and state for monitoring."""
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self.failure_count,
            "failure_threshold": self.failure_threshold,
            "last_failure_time": self.last_failure_time,
            "recovery_timeout": self.recovery_timeout,
        }


class CircuitBreakerRegistry:
    """Registry that creates and looks up circuit breakers by name."""

    def __init__(self):
        self._breakers: dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
    ) -> CircuitBreaker:
        """
        Return the breaker registered under ``name``, creating it if absent.

        The configuration arguments are only used when the breaker is
        created; an existing breaker is returned unchanged.
        """
        breaker = self._breakers.get(name)
        if breaker is None:
            breaker = CircuitBreaker(
                failure_threshold=failure_threshold,
                recovery_timeout=recovery_timeout,
                expected_exception=expected_exception,
                name=name,
            )
            self._breakers[name] = breaker
        return breaker

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Return the breaker registered under ``name``, or None."""
        return self._breakers.get(name)

    def get_all_states(self) -> dict:
        """Return ``{name: breaker.get_state()}`` for every registered breaker."""
        return {name: breaker.get_state() for name, breaker in self._breakers.items()}

    def reset(self, name: str) -> None:
        """Manually reset the named breaker to CLOSED; unknown names are ignored."""
        breaker = self._breakers.get(name)
        if breaker is None:
            return
        breaker.reset()
        logger.info("Circuit breaker '%s' manually reset", name)


# Global registry instance
circuit_breaker_registry = CircuitBreakerRegistry()