# File: bakery-ia/services/training/app/utils/circuit_breaker.py
# Size: 199 lines, 6.4 KiB (Python)

"""
Circuit Breaker Pattern Implementation
Protects against cascading failures from external service calls
"""
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Optional
import logging
from functools import wraps
# Module-level standard-library logger, named after this module.
logger = logging.getLogger(__name__)
class CircuitState(Enum):
    """Lifecycle states a circuit breaker can be in."""

    # Normal operation: calls pass through.
    CLOSED = "closed"
    # Circuit is open: requests are rejected.
    OPEN = "open"
    # Probing whether the protected service has recovered.
    HALF_OPEN = "half_open"
class CircuitBreakerError(Exception):
    """Signals that a call was rejected because the circuit breaker is open."""
class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures.

    States:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Too many failures, rejecting all requests
    - HALF_OPEN: Testing if service recovered; calls pass through until one
      succeeds (closing the circuit) or fails with ``expected_exception``
      (re-opening it)

    Use either ``await breaker.call(func, *args, **kwargs)`` or apply the
    instance as a decorator to an async function.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception,
        name: str = "circuit_breaker"
    ):
        """
        Initialize circuit breaker.

        Args:
            failure_threshold: Number of failures before opening circuit
            recovery_timeout: Seconds to wait before attempting recovery
            expected_exception: Exception type to count as a failure
                (other exceptions propagate without being counted)
            name: Name for logging purposes
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name
        self.failure_count = 0
        # time.time() timestamp of the most recent counted failure, or None.
        self.last_failure_time: Optional[float] = None
        self.state = CircuitState.CLOSED

    def _record_success(self) -> None:
        """Record a successful call: clear failure tracking and close a half-open circuit."""
        self.failure_count = 0
        self.last_failure_time = None
        if self.state == CircuitState.HALF_OPEN:
            logger.info(f"Circuit breaker '{self.name}' recovered, closing circuit")
            self.state = CircuitState.CLOSED

    def _record_failure(self) -> None:
        """Record a failed call and open the circuit once the threshold is reached."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        # failure_count is not reset when the circuit opens, so a failed
        # half-open probe is already at/above the threshold and re-opens the
        # circuit right here.
        if self.failure_count >= self.failure_threshold:
            if self.state != CircuitState.OPEN:
                logger.warning(
                    f"Circuit breaker '{self.name}' opened after {self.failure_count} failures"
                )
            self.state = CircuitState.OPEN

    def _should_attempt_reset(self) -> bool:
        """Return True when the circuit is open and recovery_timeout has elapsed since the last failure."""
        return (
            self.state == CircuitState.OPEN
            and self.last_failure_time is not None
            and time.time() - self.last_failure_time >= self.recovery_timeout
        )

    async def call(self, func: Callable, *args, **kwargs) -> Any:
        """
        Execute function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            CircuitBreakerError: If circuit is open and the recovery timeout
                has not elapsed yet
            Exception: Whatever func raises is re-raised unchanged; only
                instances of expected_exception are counted as failures
        """
        # Check if circuit is open
        if self.state == CircuitState.OPEN:
            if self._should_attempt_reset():
                logger.info(f"Circuit breaker '{self.name}' attempting recovery (half-open)")
                self.state = CircuitState.HALF_OPEN
            else:
                raise CircuitBreakerError(
                    f"Circuit breaker '{self.name}' is open. "
                    f"Service unavailable for {self.recovery_timeout}s after {self.failure_count} failures."
                )

        try:
            # Execute the function
            result = await func(*args, **kwargs)
            self._record_success()
            return result
        except self.expected_exception as e:
            self._record_failure()
            # The stdlib Logger rejects arbitrary keyword arguments (it would
            # raise TypeError and mask the original exception), so the
            # details are folded into the message itself.
            logger.error(
                f"Circuit breaker '{self.name}' caught failure: "
                f"error={e}, failure_count={self.failure_count}, "
                f"state={self.state.value}"
            )
            raise

    def __call__(self, func: Callable) -> Callable:
        """Decorator interface: wrap an async function so every call goes through ``call``."""
        @wraps(func)
        async def wrapper(*args, **kwargs):
            return await self.call(func, *args, **kwargs)
        return wrapper

    def get_state(self) -> dict:
        """Get current circuit breaker state for monitoring"""
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self.failure_count,
            "failure_threshold": self.failure_threshold,
            "last_failure_time": self.last_failure_time,
            "recovery_timeout": self.recovery_timeout
        }
class CircuitBreakerRegistry:
    """Keeps track of named circuit breakers so they can be looked up and reused."""

    def __init__(self):
        # Maps breaker name -> CircuitBreaker instance.
        self._breakers: dict[str, CircuitBreaker] = {}

    def get_or_create(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ) -> CircuitBreaker:
        """
        Return the breaker registered under ``name``, creating it on first use.

        The configuration arguments only take effect when the breaker is
        created; an already-registered breaker is returned unchanged.
        """
        existing = self._breakers.get(name)
        if existing is not None:
            return existing

        created = CircuitBreaker(
            failure_threshold=failure_threshold,
            recovery_timeout=recovery_timeout,
            expected_exception=expected_exception,
            name=name
        )
        self._breakers[name] = created
        return created

    def get(self, name: str) -> Optional[CircuitBreaker]:
        """Look up a breaker by name; returns None when it is not registered."""
        try:
            return self._breakers[name]
        except KeyError:
            return None

    def get_all_states(self) -> dict:
        """Collect ``get_state()`` of every registered breaker, keyed by name."""
        states = {}
        for breaker_name, registered in self._breakers.items():
            states[breaker_name] = registered.get_state()
        return states

    def reset(self, name: str):
        """Force the named breaker back to CLOSED with cleared failure tracking; unknown names are ignored."""
        target = self._breakers.get(name)
        if target is None:
            return

        target.state = CircuitState.CLOSED
        target.last_failure_time = None
        target.failure_count = 0
        logger.info(f"Circuit breaker '{name}' manually reset")
# Global registry instance: a single module-level CircuitBreakerRegistry,
# created when this module is imported.
circuit_breaker_registry = CircuitBreakerRegistry()