Files
bakery-ia/shared/database/unit_of_work.py

304 lines
9.8 KiB
Python
Raw Normal View History

2026-01-21 17:17:16 +01:00
"""
Unit of Work Pattern Implementation
Manages transactions across multiple repositories with event publishing
"""
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from abc import ABC, abstractmethod
import structlog
from .repository import BaseRepository
from .exceptions import TransactionError
# Generic placeholders so repository registration keeps the concrete
# repository/model types visible to type checkers.
Model = TypeVar("Model")
Repository = TypeVar("Repository", bound=BaseRepository)

# Module-wide structured logger used by every class and function below.
logger = structlog.get_logger()
class BaseEvent(ABC):
    """
    Abstract base for domain events queued on a unit of work.

    Concrete events must implement :meth:`to_dict`, which produces the
    dictionary form used when the event is published.
    """

    def __init__(self, event_type: str, data: Dict[str, Any]):
        # Both values are kept as plain public attributes; the unit of work
        # reads ``event_type`` when logging.
        self.event_type, self.data = event_type, data

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a dictionary suitable for publishing."""
        ...
class DomainEvent(BaseEvent):
    """
    Ready-to-use event whose serialized form is just its type and payload.
    """

    def to_dict(self) -> Dict[str, Any]:
        # ``data`` is returned as-is (same object), not copied.
        serialized: Dict[str, Any] = dict(
            event_type=self.event_type,
            data=self.data,
        )
        return serialized
class UnitOfWork:
    """
    Unit of Work pattern for managing transactions and coordinating repositories.

    Every repository registered through :meth:`register_repository` is built
    on the same ``AsyncSession``, so their changes are committed or rolled
    back together. Domain events queued with :meth:`add_event` are handed to
    :meth:`_publish_events` only after a successful commit, and are discarded
    on rollback.

    Usage:
        async with UnitOfWork(session) as uow:
            user_repo = uow.register_repository("users", UserRepository, User)
            sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
            user = await user_repo.create(user_data)
            sale = await sales_repo.create(sales_data)
            await uow.commit()
    """

    def __init__(self, session: AsyncSession, auto_commit: bool = False):
        """
        Args:
            session: Session shared by every repository registered here.
            auto_commit: When True, leaving the ``async with`` block without
                an exception commits automatically.
        """
        self.session = session
        self.auto_commit = auto_commit
        self._repositories: Dict[str, BaseRepository] = {}
        self._events: List[BaseEvent] = []
        self._committed = False
        self._rolled_back = False

    def register_repository(
        self,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ) -> Repository:
        """
        Register a repository with the unit of work.

        The repository is constructed as
        ``repository_class(model_class, self.session, **kwargs)`` so it shares
        this unit of work's session.

        Args:
            name: Unique name for the repository.
            repository_class: Repository class to instantiate.
            model_class: SQLAlchemy model class.
            **kwargs: Additional arguments for the repository constructor.

        Returns:
            The instantiated repository. If ``name`` is already registered,
            the existing instance is returned and the other arguments are
            ignored.
        """
        if name in self._repositories:
            logger.warning(f"Repository '{name}' already registered, returning existing instance")
            return self._repositories[name]

        repository = repository_class(model_class, self.session, **kwargs)
        self._repositories[name] = repository
        logger.debug("Registered repository", name=name, model=model_class.__name__)
        return repository

    def get_repository(self, name: str) -> Optional[Repository]:
        """Return the repository registered under ``name``, or None."""
        return self._repositories.get(name)

    def add_event(self, event: BaseEvent):
        """Queue a domain event to be published after a successful commit."""
        self._events.append(event)
        logger.debug("Added event", event_type=event.event_type)

    async def commit(self):
        """
        Commit the transaction, then publish queued domain events.

        A second call after a successful commit is a logged no-op.

        Raises:
            TransactionError: if the unit of work was already rolled back, or
                if the database commit fails (a rollback is attempted first).
        """
        if self._committed:
            logger.warning("Unit of Work already committed")
            return

        if self._rolled_back:
            raise TransactionError("Cannot commit after rollback")

        try:
            await self.session.commit()
        except SQLAlchemyError as e:
            logger.error("Failed to commit Unit of Work", error=str(e))
            try:
                await self.rollback()
            except TransactionError as rollback_error:
                # Keep the commit failure as the error callers see; a failed
                # rollback is only logged here.
                logger.error("Rollback after failed commit also failed",
                             error=str(rollback_error))
            raise TransactionError(f"Commit failed: {str(e)}") from e

        self._committed = True

        # Capture the count before publishing: _publish_events() clears the
        # queue, so reading it afterwards would always report 0.
        events_count = len(self._events)

        # Publishing happens outside the SQLAlchemyError handler on purpose:
        # the data is already committed, so a publishing problem must not be
        # reported as a commit failure or trigger a rollback.
        await self._publish_events()

        logger.debug("Unit of Work committed successfully",
                     repositories=list(self._repositories.keys()),
                     events_published=events_count)

    async def rollback(self):
        """
        Roll back the transaction and discard queued events.

        A second call is a logged no-op.

        Raises:
            TransactionError: if the session rollback fails.
        """
        if self._rolled_back:
            logger.warning("Unit of Work already rolled back")
            return

        try:
            await self.session.rollback()
        except SQLAlchemyError as e:
            logger.error("Failed to rollback Unit of Work", error=str(e))
            raise TransactionError(f"Rollback failed: {str(e)}") from e

        self._rolled_back = True
        self._events.clear()  # Events for rolled-back work must never be published
        logger.debug("Unit of Work rolled back",
                     repositories=list(self._repositories.keys()))

    async def _publish_events(self):
        """
        Publish domain events (override in subclasses for actual publishing).

        The default implementation only logs each event, then clears the queue.
        """
        if not self._events:
            return

        for event in self._events:
            logger.info("Publishing event",
                        event_type=event.event_type,
                        event_data=event.to_dict())

        # Clear events after publishing so they are not published twice.
        self._events.clear()

    async def __aenter__(self):
        """Async context manager entry; returns the unit of work itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Roll back if the block raised; otherwise commit when ``auto_commit``
        is enabled and no commit has happened yet.

        Always returns False, so exceptions from the block propagate.
        """
        if exc_type is not None:
            await self.rollback()
            return False

        if self.auto_commit and not self._committed:
            await self.commit()

        return False
class ServiceUnitOfWork(UnitOfWork):
    """
    Service-specific Unit of Work with event publishing integration.

    Queued events are passed one at a time to ``event_publisher.publish(event)``.
    Publishing is best-effort: a publisher failure is logged, not raised.
    Nothing in this class retries failed events.

    Example usage with message publishing:
        class AuthUnitOfWork(ServiceUnitOfWork):
            def __init__(self, session: AsyncSession, message_publisher=None):
                super().__init__(session)
                self.message_publisher = message_publisher

            async def _publish_events(self):
                for event in self._events:
                    if self.message_publisher:
                        await self.message_publisher.publish(
                            topic="auth.events",
                            message=event.to_dict()
                        )
    """

    def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
        super().__init__(session, auto_commit)
        # Expected to expose an awaitable ``publish(event)``. When None,
        # queued events are left in place and nothing is published.
        self.event_publisher = event_publisher

    async def _publish_events(self):
        """Publish queued events through ``event_publisher`` (best-effort)."""
        if not self._events or not self.event_publisher:
            return

        # Detach the batch up front so a mid-batch failure does not leave
        # already-published events sitting in the queue.
        pending = list(self._events)
        self._events.clear()

        for index, event in enumerate(pending):
            try:
                await self.event_publisher.publish(event)
            except Exception as e:
                # Don't raise: this runs after the commit, so the transaction
                # must not be reported as failed. Stop at the first failure to
                # preserve event ordering, and report what was not published.
                logger.error("Failed to publish events",
                             error=str(e),
                             failed_event_type=event.event_type,
                             unpublished_events=len(pending) - index)
                return
            logger.debug("Published event via publisher",
                         event_type=event.event_type)
# ===== TRANSACTION CONTEXT MANAGER =====
@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
    """
    Simple transaction context manager for single-repository operations.

    Yields ``session`` unchanged. On normal exit it commits when
    ``auto_commit`` is True. If the body (or the commit itself) raises, the
    session is rolled back and the original exception is re-raised.

    Usage:
        async with transaction_scope(session) as tx_session:
            user = User(name="John")
            tx_session.add(user)
            # Auto-commits on success, rolls back on exception
    """
    try:
        yield session
        if auto_commit:
            await session.commit()
    except BaseException as e:
        # BaseException rather than Exception so that task cancellation
        # (asyncio.CancelledError is a BaseException since Python 3.8) also
        # rolls back instead of leaving the transaction open.
        try:
            await session.rollback()
        except Exception as rollback_error:
            # Don't let a rollback failure hide the original error.
            logger.error("Transaction scope rollback failed", error=str(rollback_error))
        logger.error("Transaction scope failed", error=str(e))
        raise
# ===== UTILITIES =====
class RepositoryRegistry:
    """
    Registry for commonly used repository configurations.

    Configurations live in a class-level dict, so all users of the class
    (including subclasses) share a single registry. Registering an existing
    name overwrites the previous configuration.
    """

    # name -> {"repository_class": ..., "model_class": ..., "kwargs": {...}}
    _registry: Dict[str, Dict[str, Any]] = {}

    @classmethod
    def register(
        cls,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ) -> None:
        """
        Register a repository configuration.

        Args:
            name: Key used later by :meth:`create_repository`.
            repository_class: Repository class to instantiate.
            model_class: Model class passed as the first constructor argument.
            **kwargs: Extra keyword arguments stored and forwarded on creation.
        """
        cls._registry[name] = {
            "repository_class": repository_class,
            "model_class": model_class,
            "kwargs": kwargs
        }
        logger.debug("Registered repository configuration", name=name)

    @classmethod
    def create_repository(cls, name: str, session: AsyncSession) -> Optional[Repository]:
        """
        Create a repository instance from a stored configuration.

        Returns:
            ``repository_class(model_class, session, **kwargs)``, or None (with
            a warning) when ``name`` was never registered.
        """
        config = cls._registry.get(name)
        if not config:
            logger.warning(f"Repository configuration '{name}' not found in registry")
            return None

        return config["repository_class"](
            config["model_class"],
            session,
            **config["kwargs"]
        )

    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered repository names, in registration order."""
        return list(cls._registry.keys())
# ===== FACTORY FUNCTIONS =====
def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
    """
    Build a plain :class:`UnitOfWork` bound to ``session``.

    Any extra keyword arguments (such as ``auto_commit``) are forwarded to
    the constructor unchanged.
    """
    uow = UnitOfWork(session, **kwargs)
    return uow
def create_service_unit_of_work(
    session: AsyncSession,
    event_publisher=None,
    **kwargs
) -> ServiceUnitOfWork:
    """
    Build a :class:`ServiceUnitOfWork` bound to ``session``.

    ``event_publisher`` is passed through as the unit of work's publisher;
    remaining keyword arguments (such as ``auto_commit``) go to the
    constructor unchanged.
    """
    return ServiceUnitOfWork(session, event_publisher=event_publisher, **kwargs)