304 lines
9.8 KiB
Python
304 lines
9.8 KiB
Python
|
|
"""
Unit of Work Pattern Implementation

Manages transactions across multiple repositories, with domain events
published only after a successful commit.
"""
|
||
|
|
|
||
|
|
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy.exc import SQLAlchemyError
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
import structlog
|
||
|
|
|
||
|
|
from .repository import BaseRepository
|
||
|
|
from .exceptions import TransactionError
|
||
|
|
|
||
|
|
# Structured logger shared by every class and function in this module.
logger = structlog.get_logger()

# Type variable for the model class handed to a repository constructor.
Model = TypeVar("Model")

# Type variable for concrete repository types; must subclass BaseRepository.
Repository = TypeVar("Repository", bound=BaseRepository)
|
||
|
|
|
||
|
|
|
||
|
|
class BaseEvent(ABC):
    """Abstract root of all domain events.

    An event is a type tag plus a payload dictionary. Concrete subclasses
    decide how the event is serialized for publishing by implementing
    :meth:`to_dict`.

    Args:
        event_type: Identifier describing what happened.
        data: Event payload; stored as-is (not copied or validated).
    """

    def __init__(self, event_type: str, data: Dict[str, Any]):
        # Plain public attributes, assigned in one statement.
        self.event_type, self.data = event_type, data

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary representation of this event for publishing.

        Subclasses must override this. The base body does nothing and
        returns ``None`` if reached through ``super()``.
        """
|
||
|
|
|
||
|
|
|
||
|
|
class DomainEvent(BaseEvent):
    """General-purpose domain event serialized as its type tag and payload."""

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{"event_type": <type>, "data": <payload>}``.

        The payload is included by reference, not copied.
        """
        serialized: Dict[str, Any] = {}
        serialized["event_type"] = self.event_type
        serialized["data"] = self.data
        return serialized
|
||
|
|
|
||
|
|
|
||
|
|
class UnitOfWork:
    """
    Unit of Work pattern for managing transactions and coordinating repositories.

    Every repository registered with an instance is constructed with the same
    ``session``, so their changes are committed or rolled back together.
    Events added through :meth:`add_event` are buffered and handed to
    :meth:`_publish_events` only after a successful commit. A rollback
    discards them.

    Usage:
        async with UnitOfWork(session) as uow:
            user_repo = uow.register_repository("users", UserRepository, User)
            sales_repo = uow.register_repository("sales", SalesRepository, SalesData)

            user = await user_repo.create(user_data)
            sale = await sales_repo.create(sales_data)

            await uow.commit()
    """

    def __init__(self, session: AsyncSession, auto_commit: bool = False):
        """
        Args:
            session: Session passed to every repository registered here and
                used for commit/rollback.
            auto_commit: If True, leaving the ``async with`` block without an
                exception commits automatically, unless the unit was already
                committed or rolled back.
        """
        self.session = session
        self.auto_commit = auto_commit
        self._repositories: Dict[str, BaseRepository] = {}
        self._events: List[BaseEvent] = []
        self._committed = False
        self._rolled_back = False

    def register_repository(
        self,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ) -> Repository:
        """
        Register a repository with the unit of work.

        If ``name`` is already registered, the existing instance is returned
        and the other arguments are ignored.

        Args:
            name: Unique name for the repository
            repository_class: Repository class to instantiate; called as
                ``repository_class(model_class, session, **kwargs)``
            model_class: SQLAlchemy model class
            **kwargs: Additional arguments for repository

        Returns:
            Instantiated (or previously registered) repository
        """
        if name in self._repositories:
            logger.warning(f"Repository '{name}' already registered, returning existing instance")
            return self._repositories[name]

        repository = repository_class(model_class, self.session, **kwargs)
        self._repositories[name] = repository

        logger.debug("Registered repository", name=name, model=model_class.__name__)
        return repository

    def get_repository(self, name: str) -> Optional[BaseRepository]:
        """Return the repository registered under ``name``, or ``None`` if absent."""
        return self._repositories.get(name)

    def add_event(self, event: BaseEvent) -> None:
        """Buffer a domain event to be published after a successful commit."""
        self._events.append(event)
        logger.debug("Added event", event_type=event.event_type)

    async def commit(self) -> None:
        """Commit the transaction, then publish buffered events.

        A second call after a successful commit only logs a warning.

        Raises:
            TransactionError: If the unit was already rolled back, or if the
                session commit fails with ``SQLAlchemyError``. In the latter
                case the session is rolled back first.
        """
        if self._committed:
            logger.warning("Unit of Work already committed")
            return

        if self._rolled_back:
            raise TransactionError("Cannot commit after rollback")

        # Only the database commit belongs in this try. Publishing happens after
        # the data is durable, and rolling back at that point would be
        # meaningless and would leave the unit both committed and rolled back.
        try:
            await self.session.commit()
        except SQLAlchemyError as e:
            await self.rollback()
            logger.error("Failed to commit Unit of Work", error=str(e))
            raise TransactionError(f"Commit failed: {str(e)}") from e

        self._committed = True

        # Capture the count before publishing, because _publish_events()
        # clears the buffer.
        pending_events = len(self._events)
        await self._publish_events()

        logger.debug("Unit of Work committed successfully",
                     repositories=list(self._repositories.keys()),
                     events_published=pending_events)

    async def rollback(self) -> None:
        """Roll back the transaction and discard buffered events.

        A second call after a successful rollback only logs a warning.

        Raises:
            TransactionError: If the session rollback fails with ``SQLAlchemyError``.
        """
        if self._rolled_back:
            logger.warning("Unit of Work already rolled back")
            return

        try:
            await self.session.rollback()
        except SQLAlchemyError as e:
            logger.error("Failed to rollback Unit of Work", error=str(e))
            raise TransactionError(f"Rollback failed: {str(e)}") from e

        self._rolled_back = True
        self._events.clear()  # events of an abandoned transaction must not be published

        logger.debug("Unit of Work rolled back",
                     repositories=list(self._repositories.keys()))

    async def _publish_events(self) -> None:
        """Publish buffered events (override in subclasses for actual publishing).

        The default implementation only logs each event, then clears the buffer.
        """
        if not self._events:
            return

        for event in self._events:
            logger.info("Publishing event",
                        event_type=event.event_type,
                        event_data=event.to_dict())

        self._events.clear()

    async def __aenter__(self):
        """Async context manager entry; returns this unit of work."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit.

        Rolls back if the block raised. Otherwise auto-commits when enabled.
        Never suppresses the exception.
        """
        if exc_type is not None:
            await self.rollback()
            return False

        # Skip auto-commit after an explicit rollback inside the block;
        # commit() would otherwise raise "Cannot commit after rollback".
        if self.auto_commit and not self._committed and not self._rolled_back:
            await self.commit()

        return False
|
||
|
|
|
||
|
|
|
||
|
|
class ServiceUnitOfWork(UnitOfWork):
    """
    Service-specific Unit of Work with event publishing integration.

    After a successful commit, each buffered event is passed to
    ``event_publisher.publish(event)``. If no publisher is configured, nothing
    is published and the buffer is left as-is.

    Example usage with message publishing:

        class AuthUnitOfWork(ServiceUnitOfWork):
            def __init__(self, session: AsyncSession, message_publisher=None):
                super().__init__(session)
                self.message_publisher = message_publisher

            async def _publish_events(self):
                for event in self._events:
                    if self.message_publisher:
                        await self.message_publisher.publish(
                            topic="auth.events",
                            message=event.to_dict()
                        )
                self._events.clear()
    """

    def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
        """
        Args:
            session: Session shared by the registered repositories.
            event_publisher: Object exposing an awaitable ``publish(event)``;
                ``None`` disables publishing.
            auto_commit: Forwarded to :class:`UnitOfWork`.
        """
        super().__init__(session, auto_commit)
        self.event_publisher = event_publisher

    async def _publish_events(self) -> None:
        """Publish buffered events through ``event_publisher``, one at a time.

        Publishing is best-effort. A failing ``publish`` call is logged and
        never propagated, so it cannot disturb the already-committed
        transaction. Each event is attempted independently. Published events
        are removed from the buffer, while events whose publish raised remain
        in ``self._events``. Nothing in this class retries them automatically.
        """
        if not self._events or not self.event_publisher:
            return

        failed: List[BaseEvent] = []
        for event in self._events:
            try:
                await self.event_publisher.publish(event)
            except Exception as e:
                # Keep going: one bad event must not block the rest.
                logger.error("Failed to publish events",
                             error=str(e),
                             event_type=event.event_type)
                failed.append(event)
            else:
                logger.debug("Published event via publisher",
                             event_type=event.event_type)

        # Keep only the failures. The in-place slice assignment preserves the
        # list object in case anything else holds a reference to it.
        self._events[:] = failed
|
||
|
|
|
||
|
|
|
||
|
|
# ===== TRANSACTION CONTEXT MANAGER =====
|
||
|
|
|
||
|
|
@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
    """
    Simple transaction context manager for single-repository operations.

    Yields ``session`` itself. On normal exit the session is committed when
    ``auto_commit`` is True; otherwise it is left untouched. If the body, or
    the commit, raises, including through task cancellation, the session is
    rolled back and the exception is re-raised.

    Usage:
        async with transaction_scope(session) as tx_session:
            user = User(name="John")
            tx_session.add(user)
            # Auto-commits on success, rolls back on exception

    Args:
        session: Session to yield and then commit or roll back.
        auto_commit: Commit on successful exit (default True).
    """
    try:
        yield session
        if auto_commit:
            await session.commit()
    except BaseException as e:
        # Catch BaseException, not just Exception. asyncio.CancelledError has
        # been a BaseException since Python 3.8, so a cancelled task would
        # otherwise skip the rollback and leave the transaction open.
        await session.rollback()
        if isinstance(e, Exception):
            # Cancellation and other control-flow exits are not failures worth logging.
            logger.error("Transaction scope failed", error=str(e))
        raise
|
||
|
|
|
||
|
|
|
||
|
|
# ===== UTILITIES =====
|
||
|
|
|
||
|
|
class RepositoryRegistry:
    """Registry of named repository configurations.

    Configurations live in the class attribute ``_registry``. Every caller,
    and every subclass that does not define its own ``_registry``, shares the
    same mapping.
    """

    # name -> {"repository_class": cls, "model_class": cls, "kwargs": dict}
    _registry: Dict[str, Dict[str, Any]] = {}

    @classmethod
    def register(
        cls,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ):
        """Register a repository configuration under ``name``.

        Re-registering an existing name silently replaces its configuration.

        Args:
            name: Key used later with :meth:`create_repository`.
            repository_class: Class instantiated as
                ``repository_class(model_class, session, **kwargs)``.
            model_class: Model class passed to the repository.
            **kwargs: Extra constructor arguments stored for later use.
        """
        cls._registry[name] = {
            "repository_class": repository_class,
            "model_class": model_class,
            "kwargs": kwargs,
        }
        logger.debug("Registered repository configuration", name=name)

    @classmethod
    def create_repository(cls, name: str, session: AsyncSession) -> Optional[BaseRepository]:
        """Instantiate the repository registered under ``name`` bound to ``session``.

        Returns:
            The new repository, or ``None`` (with a warning logged) if ``name``
            was never registered.
        """
        config = cls._registry.get(name)
        if config is None:
            logger.warning(f"Repository configuration '{name}' not found in registry")
            return None

        return config["repository_class"](
            config["model_class"],
            session,
            **config["kwargs"]
        )

    @classmethod
    def list_registered(cls) -> List[str]:
        """Return the names of all registered configurations, in insertion order."""
        return list(cls._registry)
|
||
|
|
|
||
|
|
|
||
|
|
# ===== FACTORY FUNCTIONS =====
|
||
|
|
|
||
|
|
def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
    """Build a plain :class:`UnitOfWork` bound to ``session``.

    Any keyword arguments (e.g. ``auto_commit``) are forwarded unchanged to
    the constructor.
    """
    uow = UnitOfWork(session, **kwargs)
    return uow
|
||
|
|
|
||
|
|
|
||
|
|
def create_service_unit_of_work(
    session: AsyncSession,
    event_publisher=None,
    **kwargs
) -> ServiceUnitOfWork:
    """Build a :class:`ServiceUnitOfWork` bound to ``session``.

    ``event_publisher`` becomes the unit's publisher. Remaining keyword
    arguments (e.g. ``auto_commit``) are forwarded unchanged to the constructor.
    """
    return ServiceUnitOfWork(session, event_publisher=event_publisher, **kwargs)
|