294 lines
9.4 KiB
Python
294 lines
9.4 KiB
Python
|
|
"""
|
||
|
|
Saga Pattern Implementation
|
||
|
|
|
||
|
|
Provides distributed transaction coordination with compensation logic
|
||
|
|
for microservices architecture.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
import inspect
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Callable, List, Dict, Any, Optional, Tuple
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class SagaStepStatus(str, Enum):
    """Status of a saga step.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (e.g. ``SagaStepStatus.PENDING == "pending"``).
    """

    PENDING = "pending"            # default state of a newly created step
    IN_PROGRESS = "in_progress"    # the step's action is running
    COMPLETED = "completed"        # the step's action finished without raising
    FAILED = "failed"              # the step's action raised an exception
    COMPENSATING = "compensating"  # the step's compensation is running
    COMPENSATED = "compensated"    # the step's compensation finished without raising
|
||
|
|
|
||
|
|
|
||
|
|
class SagaStatus(str, Enum):
    """Overall saga status.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (e.g. ``SagaStatus.COMPLETED == "completed"``).
    """

    PENDING = "pending"            # default state before execution starts
    IN_PROGRESS = "in_progress"    # steps are being executed
    COMPLETED = "completed"        # every step finished successfully
    FAILED = "failed"              # the saga ended with an error
    COMPENSATING = "compensating"  # compensation of completed steps is running
    COMPENSATED = "compensated"    # the compensation pass finished after a step failure
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class SagaStep:
    """
    A single step in a saga with compensation logic.

    Args:
        name: Human-readable step name
        action: Callable to execute; may be synchronous or asynchronous
        compensation: Optional callable that undoes the action
        action_args: Positional arguments for the action function
        action_kwargs: Keyword arguments for the action function

    The remaining fields hold runtime state and are normally left at their
    defaults when the step is created.
    """

    name: str
    action: Callable[..., Any]
    compensation: Optional[Callable[..., Any]] = None
    # default_factory gives every step its own container instance.
    action_args: Tuple[Any, ...] = field(default_factory=tuple)
    action_kwargs: Dict[str, Any] = field(default_factory=dict)

    # Runtime state
    status: SagaStepStatus = SagaStepStatus.PENDING
    result: Any = None                       # value returned by the action
    error: Optional[Exception] = None        # exception captured for this step
    started_at: Optional[datetime] = None    # when the action was started
    completed_at: Optional[datetime] = None  # when the action finished or failed
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class SagaExecution:
    """Tracks execution state of a saga"""

    saga_id: str
    status: SagaStatus = SagaStatus.PENDING
    # default_factory gives every execution its own list instance.
    steps: List[SagaStep] = field(default_factory=list)
    current_step: int = 0                    # index into ``steps``
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[Exception] = None        # error that ended the saga, if any
|
||
|
|
|
||
|
|
|
||
|
|
class SagaCoordinator:
    """
    Coordinates saga execution with automatic compensation on failure.

    Steps run sequentially in the order they were added. When a step fails,
    the compensation callables of the steps that already completed are
    invoked in reverse order.

    Example:
        ```python
        saga = SagaCoordinator()

        saga.add_step(
            "create_order",
            action=create_order,
            compensation=delete_order,
            action_args=(order_data,)
        )

        saga.add_step(
            "reserve_inventory",
            action=reserve_inventory,
            compensation=release_inventory,
            action_args=(order_id, items)
        )

        success, result, error = await saga.execute()
        ```
    """

    def __init__(self, saga_id: Optional[str] = None):
        """
        Create a coordinator with no steps.

        Args:
            saga_id: Identifier for this saga. A random UUID4 string is
                generated when it is omitted or empty.
        """
        self.execution = SagaExecution(
            saga_id=saga_id or str(uuid.uuid4())
        )
        # Steps whose action succeeded, in execution order; walked in
        # reverse by _compensate().
        self._completed_steps: List[SagaStep] = []

    def add_step(
        self,
        name: str,
        action: Callable,
        compensation: Optional[Callable] = None,
        action_args: tuple = (),
        action_kwargs: Optional[dict] = None
    ) -> None:
        """
        Add a step to the saga.

        Args:
            name: Human-readable step name
            action: Callable to execute (sync or async)
            compensation: Callable to undo the action (optional, sync or
                async). It receives the action's result as its only
                argument, or no arguments when that result is None.
            action_args: Positional arguments for the action function
            action_kwargs: Keyword arguments for the action function
        """
        step = SagaStep(
            name=name,
            action=action,
            compensation=compensation,
            # Store copies so later mutation by the caller cannot alter
            # an already-registered step.
            action_args=tuple(action_args),
            action_kwargs=dict(action_kwargs or {})
        )
        self.execution.steps.append(step)
        logger.debug(f"Added step '{name}' to saga {self.execution.saga_id}")

    @staticmethod
    async def _invoke(func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` and, if what it returns is awaitable, await it."""
        # Inspecting the returned object (rather than the callable) also
        # covers functools.partial wrappers and objects with an async
        # __call__, and avoids the deprecated asyncio.iscoroutinefunction.
        outcome = func(*args, **kwargs)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    async def execute(self) -> Tuple[bool, Optional[Any], Optional[Exception]]:
        """
        Execute all saga steps in sequence.

        On the first failing step, completed steps are compensated in
        reverse order and execution stops. The saga then ends as
        COMPENSATED, or as FAILED if any compensation itself raised.

        Returns:
            Tuple of (success: bool, final_result: Any, error: Optional[Exception]).
            ``final_result`` is the result of the last step on success
            (None for a saga without steps) and None on failure.
        """
        # Reset per-run state so a repeated execute() call neither
        # compensates a step twice nor reports a stale error.
        self._completed_steps.clear()
        self.execution.error = None
        self.execution.completed_at = None

        self.execution.status = SagaStatus.IN_PROGRESS
        self.execution.started_at = datetime.now()

        logger.info(
            f"Starting saga {self.execution.saga_id} with {len(self.execution.steps)} steps"
        )

        try:
            # Execute each step
            for idx, step in enumerate(self.execution.steps):
                self.execution.current_step = idx

                if await self._execute_step(step):
                    # Step succeeded
                    self._completed_steps.append(step)
                    continue

                # Step failed, trigger compensation
                logger.error(
                    f"Saga {self.execution.saga_id} failed at step '{step.name}': {step.error}"
                )
                fully_compensated = await self._compensate()

                # Only claim COMPENSATED when every compensation ran cleanly.
                self.execution.status = (
                    SagaStatus.COMPENSATED if fully_compensated else SagaStatus.FAILED
                )
                self.execution.completed_at = datetime.now()
                self.execution.error = step.error

                return False, None, step.error

            # All steps completed successfully
            self.execution.status = SagaStatus.COMPLETED
            self.execution.completed_at = datetime.now()

            # Return the result of the last step
            final_result = self.execution.steps[-1].result if self.execution.steps else None

            logger.info(f"Saga {self.execution.saga_id} completed successfully")
            return True, final_result, None

        except Exception as e:
            logger.exception(f"Unexpected error in saga {self.execution.saga_id}: {e}")
            await self._compensate()

            self.execution.status = SagaStatus.FAILED
            self.execution.completed_at = datetime.now()
            self.execution.error = e

            return False, None, e

    async def _execute_step(self, step: SagaStep) -> bool:
        """
        Execute a single saga step.

        The step's status, timestamps, result and error are updated in
        place. Exceptions raised by the action are captured on
        ``step.error`` instead of being propagated.

        Returns:
            True if step succeeded, False otherwise
        """
        step.status = SagaStepStatus.IN_PROGRESS
        step.started_at = datetime.now()

        logger.info(f"Executing saga step '{step.name}'")

        try:
            # Execute the action
            step.result = await self._invoke(
                step.action, *step.action_args, **step.action_kwargs
            )
        except Exception as e:
            step.error = e
            step.status = SagaStepStatus.FAILED
            step.completed_at = datetime.now()

            logger.error(f"Saga step '{step.name}' failed: {e}")
            return False

        step.status = SagaStepStatus.COMPLETED
        step.completed_at = datetime.now()

        logger.info(f"Saga step '{step.name}' completed successfully")
        return True

    async def _compensate(self) -> bool:
        """
        Execute compensation logic for all completed steps in reverse order.

        Steps without a compensation callable are skipped with a warning.
        A compensation that raises is logged and recorded on the step, and
        the remaining steps are still compensated.

        Returns:
            True if no compensation raised (including when there was
            nothing to compensate), False otherwise.
        """
        if not self._completed_steps:
            logger.info(f"No steps to compensate for saga {self.execution.saga_id}")
            return True

        self.execution.status = SagaStatus.COMPENSATING

        logger.info(
            f"Starting compensation for saga {self.execution.saga_id} "
            f"({len(self._completed_steps)} steps to compensate)"
        )

        all_compensated = True

        # Compensate in reverse order
        for step in reversed(self._completed_steps):
            if step.compensation is None:
                logger.warning(
                    f"Step '{step.name}' has no compensation function, skipping"
                )
                continue

            step.status = SagaStepStatus.COMPENSATING
            logger.info(f"Compensating step '{step.name}'")

            # Execute compensation with the result from the original action
            compensation_args = (step.result,) if step.result is not None else ()

            try:
                await self._invoke(step.compensation, *compensation_args)
            except Exception as e:
                # The step stays in COMPENSATING (its undo never finished);
                # the error is recorded so the summary exposes it.
                all_compensated = False
                step.error = e
                logger.error(
                    f"Failed to compensate step '{step.name}': {e}", exc_info=True
                )
                # Continue compensating other steps even if one fails
                continue

            step.status = SagaStepStatus.COMPENSATED
            logger.info(f"Step '{step.name}' compensated successfully")

        logger.info(f"Compensation completed for saga {self.execution.saga_id}")
        return all_compensated

    def get_execution_summary(self) -> Dict[str, Any]:
        """Get summary of saga execution.

        Returns:
            A dict of plain values: enum members are reduced to their
            string values, timestamps to ISO-8601 strings (or None) and
            errors to ``str(error)`` (or None).
        """
        execution = self.execution
        return {
            "saga_id": execution.saga_id,
            "status": execution.status.value,
            "total_steps": len(execution.steps),
            "current_step": execution.current_step,
            "completed_steps": len(self._completed_steps),
            "started_at": execution.started_at.isoformat() if execution.started_at else None,
            "completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
            "error": str(execution.error) if execution.error else None,
            "steps": [
                {
                    "name": step.name,
                    "status": step.status.value,
                    "has_compensation": step.compensation is not None,
                    "error": str(step.error) if step.error else None
                }
                for step in execution.steps
            ]
        }
|