Files
bakery-ia/shared/utils/saga_pattern.py
2025-12-13 23:57:54 +01:00

294 lines
9.4 KiB
Python
Executable File

"""
Saga Pattern Implementation
Provides distributed transaction coordination with compensation logic
for microservices architecture.
"""
import asyncio
import inspect
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SagaStepStatus(str, Enum):
    """
    Lifecycle states of an individual saga step.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    # Forward path: the step's action
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"

    # Rollback path: the step's compensation
    COMPENSATING = "compensating"
    COMPENSATED = "compensated"
class SagaStatus(str, Enum):
    """
    Lifecycle states of a saga as a whole.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    # Forward path: running the steps
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"

    # Rollback path: undoing completed steps
    COMPENSATING = "compensating"
    COMPENSATED = "compensated"
@dataclass
class SagaStep:
    """
    One unit of work in a saga, paired with an optional undo operation.

    Args:
        name: Human-readable step name
        action: Callable that performs the step
        compensation: Callable that undoes the action (optional)
        action_args: Positional arguments passed to the action
        action_kwargs: Keyword arguments passed to the action
    """

    # --- Step definition (provided when the step is created) ---
    name: str
    action: Callable
    compensation: Optional[Callable] = None
    # An empty tuple is immutable, so it is safe as a plain default.
    action_args: tuple = ()
    # A dict is mutable: each instance needs its own fresh one.
    action_kwargs: dict = field(default_factory=dict)

    # --- Runtime state (starts empty / pending) ---
    status: SagaStepStatus = SagaStepStatus.PENDING
    result: Any = None
    error: Optional[Exception] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@dataclass
class SagaExecution:
    """
    Mutable record of one saga's progress.

    Holds the saga identifier, its ordered steps, the overall status, the
    index of the step currently being processed, start/finish timestamps
    and the error that ended the saga, if any.
    """

    saga_id: str
    status: SagaStatus = SagaStatus.PENDING
    # Each execution gets its own list of steps.
    steps: List[SagaStep] = field(default_factory=list)
    # Index into ``steps``.
    current_step: int = 0
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[Exception] = None
class SagaCoordinator:
    """
    Coordinates saga execution with automatic compensation on failure.

    Steps run sequentially in the order they were added. When a step fails,
    the compensation functions of the steps that already completed are run
    in reverse order. Actions and compensations may be synchronous or
    asynchronous: whatever they return is awaited if it is awaitable.

    Example:
        ```python
        saga = SagaCoordinator()
        saga.add_step(
            "create_order",
            action=create_order,
            compensation=delete_order,
            action_args=(order_data,)
        )
        saga.add_step(
            "reserve_inventory",
            action=reserve_inventory,
            compensation=release_inventory,
            action_args=(order_id, items)
        )
        result = await saga.execute()
        ```
    """

    def __init__(self, saga_id: Optional[str] = None):
        """
        Create an empty saga.

        Args:
            saga_id: Identifier for this saga; a random UUID4 string is
                generated when it is omitted or empty.
        """
        self.execution = SagaExecution(
            saga_id=saga_id or str(uuid.uuid4())
        )
        # Steps whose action finished successfully, in execution order.
        self._completed_steps: List[SagaStep] = []

    def add_step(
        self,
        name: str,
        action: Callable,
        compensation: Optional[Callable] = None,
        action_args: tuple = (),
        action_kwargs: Optional[dict] = None
    ):
        """
        Add a step to the saga.

        Args:
            name: Human-readable step name
            action: Function (sync or async) to execute
            compensation: Function (sync or async) to undo the action (optional)
            action_args: Arguments for the action function
            action_kwargs: Keyword arguments for the action function
        """
        step = SagaStep(
            name=name,
            action=action,
            compensation=compensation,
            action_args=action_args,
            # Compare against None rather than truthiness so that an empty
            # dict supplied by the caller is kept as-is (the caller may fill
            # it in before the step runs).
            action_kwargs=action_kwargs if action_kwargs is not None else {}
        )
        self.execution.steps.append(step)
        logger.debug(f"Added step '{name}' to saga {self.execution.saga_id}")

    async def execute(self) -> Tuple[bool, Optional[Any], Optional[Exception]]:
        """
        Execute all saga steps in sequence.

        On the first failing step, compensation is run for the steps that
        completed before it. The saga then ends as COMPENSATED when every
        compensation succeeded, or FAILED when at least one raised.

        Returns:
            Tuple of (success: bool, final_result: Any, error: Optional[Exception]).
            ``final_result`` is the result of the last step on success and
            None otherwise.
        """
        execution = self.execution

        # Reset per-run bookkeeping so that calling execute() again neither
        # compensates a step twice nor reports a stale error.
        self._completed_steps = []
        execution.error = None
        execution.completed_at = None

        execution.status = SagaStatus.IN_PROGRESS
        execution.started_at = datetime.now()
        logger.info(
            f"Starting saga {execution.saga_id} with {len(execution.steps)} steps"
        )
        try:
            # Execute each step
            for idx, step in enumerate(execution.steps):
                execution.current_step = idx
                if await self._execute_step(step):
                    # Step succeeded
                    self._completed_steps.append(step)
                    continue

                # Step failed, trigger compensation
                logger.error(
                    f"Saga {execution.saga_id} failed at step '{step.name}': {step.error}"
                )
                fully_compensated = await self._compensate()
                # Only claim COMPENSATED when every compensation succeeded.
                execution.status = (
                    SagaStatus.COMPENSATED if fully_compensated else SagaStatus.FAILED
                )
                execution.completed_at = datetime.now()
                execution.error = step.error
                return False, None, step.error

            # All steps completed successfully
            execution.status = SagaStatus.COMPLETED
            execution.completed_at = datetime.now()
            # Return the result of the last step
            final_result = execution.steps[-1].result if execution.steps else None
            logger.info(f"Saga {execution.saga_id} completed successfully")
            return True, final_result, None
        except Exception as e:
            logger.exception(f"Unexpected error in saga {execution.saga_id}: {e}")
            await self._compensate()
            execution.status = SagaStatus.FAILED
            execution.completed_at = datetime.now()
            execution.error = e
            return False, None, e

    @staticmethod
    async def _invoke(func: Callable, *args: Any, **kwargs: Any) -> Any:
        """
        Call ``func`` and await its return value when it is awaitable.

        Checking the returned object instead of the callable also covers
        callables that are not coroutine functions themselves but still
        return an awaitable (wrappers, lambdas, objects with an async
        ``__call__``).
        """
        outcome = func(*args, **kwargs)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    async def _execute_step(self, step: SagaStep) -> bool:
        """
        Execute a single saga step.

        The step's status, timestamps, result and error are updated in place.
        Exceptions raised by the action are recorded on the step, not
        propagated.

        Returns:
            True if step succeeded, False otherwise
        """
        step.status = SagaStepStatus.IN_PROGRESS
        step.started_at = datetime.now()
        # Clear leftovers from a previous run of the same step.
        step.completed_at = None
        step.result = None
        step.error = None
        logger.info(f"Executing saga step '{step.name}'")
        try:
            # Execute the action
            step.result = await self._invoke(
                step.action, *step.action_args, **step.action_kwargs
            )
        except Exception as e:
            step.error = e
            step.status = SagaStepStatus.FAILED
            step.completed_at = datetime.now()
            logger.error(f"Saga step '{step.name}' failed: {e}")
            return False

        step.status = SagaStepStatus.COMPLETED
        step.completed_at = datetime.now()
        logger.info(f"Saga step '{step.name}' completed successfully")
        return True

    async def _compensate(self) -> bool:
        """
        Execute compensation logic for all completed steps in reverse order.

        Steps without a compensation function are skipped. A compensation
        that raises is logged, recorded on its step (status FAILED, error
        set) and does not stop the remaining compensations.

        Returns:
            True if no compensation raised (including when there was nothing
            to compensate), False otherwise.
        """
        if not self._completed_steps:
            logger.info(f"No steps to compensate for saga {self.execution.saga_id}")
            return True

        self.execution.status = SagaStatus.COMPENSATING
        logger.info(
            f"Starting compensation for saga {self.execution.saga_id} "
            f"({len(self._completed_steps)} steps to compensate)"
        )
        all_succeeded = True
        # Compensate in reverse order
        for step in reversed(self._completed_steps):
            if step.compensation is None:
                logger.warning(
                    f"Step '{step.name}' has no compensation function, skipping"
                )
                continue
            step.status = SagaStepStatus.COMPENSATING
            try:
                logger.info(f"Compensating step '{step.name}'")
                # Execute compensation with the result from the original
                # action; an action that returned None gets no argument.
                compensation_args = (step.result,) if step.result is not None else ()
                await self._invoke(step.compensation, *compensation_args)
                step.status = SagaStepStatus.COMPENSATED
                logger.info(f"Step '{step.name}' compensated successfully")
            except Exception as e:
                # Record the failure instead of leaving the step stuck in
                # COMPENSATING, then continue compensating other steps.
                all_succeeded = False
                step.error = e
                step.status = SagaStepStatus.FAILED
                logger.error(f"Failed to compensate step '{step.name}': {e}")

        if all_succeeded:
            logger.info(f"Compensation completed for saga {self.execution.saga_id}")
        else:
            logger.error(
                f"Compensation for saga {self.execution.saga_id} finished with failures"
            )
        return all_succeeded

    def get_execution_summary(self) -> Dict[str, Any]:
        """
        Get summary of saga execution.

        Returns:
            A dict of plain values: statuses as their string values,
            timestamps as ISO-8601 strings (or None), errors as ``str(error)``
            (or None), plus one entry per step.
        """
        execution = self.execution
        return {
            "saga_id": execution.saga_id,
            "status": execution.status.value,
            "total_steps": len(execution.steps),
            "current_step": execution.current_step,
            "completed_steps": len(self._completed_steps),
            "started_at": execution.started_at.isoformat() if execution.started_at else None,
            "completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
            "error": str(execution.error) if execution.error else None,
            "steps": [
                {
                    "name": step.name,
                    "status": step.status.value,
                    "has_compensation": step.compensation is not None,
                    "error": str(step.error) if step.error else None
                }
                for step in execution.steps
            ]
        }