""" Unit of Work Pattern Implementation Manages transactions across multiple repositories with event publishing """ from typing import Dict, Any, List, Optional, Type, TypeVar, Generic from contextlib import asynccontextmanager from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.exc import SQLAlchemyError from abc import ABC, abstractmethod import structlog from .repository import BaseRepository from .exceptions import TransactionError logger = structlog.get_logger() Model = TypeVar('Model') Repository = TypeVar('Repository', bound=BaseRepository) class BaseEvent(ABC): """Base class for domain events""" def __init__(self, event_type: str, data: Dict[str, Any]): self.event_type = event_type self.data = data @abstractmethod def to_dict(self) -> Dict[str, Any]: """Convert event to dictionary for publishing""" pass class DomainEvent(BaseEvent): """Standard domain event implementation""" def to_dict(self) -> Dict[str, Any]: return { "event_type": self.event_type, "data": self.data } class UnitOfWork: """ Unit of Work pattern for managing transactions and coordinating repositories Usage: async with UnitOfWork(session) as uow: user_repo = uow.register_repository("users", UserRepository, User) sales_repo = uow.register_repository("sales", SalesRepository, SalesData) user = await user_repo.create(user_data) sale = await sales_repo.create(sales_data) await uow.commit() """ def __init__(self, session: AsyncSession, auto_commit: bool = False): self.session = session self.auto_commit = auto_commit self._repositories: Dict[str, BaseRepository] = {} self._events: List[BaseEvent] = [] self._committed = False self._rolled_back = False def register_repository( self, name: str, repository_class: Type[Repository], model_class: Type[Model], **kwargs ) -> Repository: """ Register a repository with the unit of work Args: name: Unique name for the repository repository_class: Repository class to instantiate model_class: SQLAlchemy model class **kwargs: Additional arguments for repository Returns: Instantiated repository """ if name in self._repositories: logger.warning(f"Repository '{name}' already registered, returning existing instance") return self._repositories[name] repository = repository_class(model_class, self.session, **kwargs) self._repositories[name] = repository logger.debug(f"Registered repository", name=name, model=model_class.__name__) return repository def get_repository(self, name: str) -> Optional[Repository]: """Get registered repository by name""" return self._repositories.get(name) def add_event(self, event: BaseEvent): """Add domain event to be published after commit""" self._events.append(event) logger.debug(f"Added event", event_type=event.event_type) async def commit(self): """Commit the transaction and publish events""" if self._committed: logger.warning("Unit of Work already committed") return if self._rolled_back: raise TransactionError("Cannot commit after rollback") try: await self.session.commit() self._committed = True # Publish events after successful commit await self._publish_events() logger.debug(f"Unit of Work committed successfully", repositories=list(self._repositories.keys()), events_published=len(self._events)) except SQLAlchemyError as e: await self.rollback() logger.error("Failed to commit Unit of Work", error=str(e)) raise TransactionError(f"Commit failed: {str(e)}") async def rollback(self): """Rollback the transaction""" if self._rolled_back: logger.warning("Unit of Work already rolled back") return try: await self.session.rollback() self._rolled_back = True self._events.clear() # Clear events on rollback logger.debug(f"Unit of Work rolled back", repositories=list(self._repositories.keys())) except SQLAlchemyError as e: logger.error("Failed to rollback Unit of Work", error=str(e)) raise TransactionError(f"Rollback failed: {str(e)}") async def _publish_events(self): """Publish domain events (override in subclasses for actual publishing)""" if not self._events: return # Default implementation just logs events # Override this method in service-specific implementations for event in self._events: logger.info(f"Publishing event", event_type=event.event_type, event_data=event.to_dict()) # Clear events after publishing self._events.clear() async def __aenter__(self): """Async context manager entry""" return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit""" if exc_type is not None: # Exception occurred, rollback await self.rollback() return False # No exception, auto-commit if enabled if self.auto_commit and not self._committed: await self.commit() return False class ServiceUnitOfWork(UnitOfWork): """ Service-specific Unit of Work with event publishing integration Example usage with message publishing: class AuthUnitOfWork(ServiceUnitOfWork): def __init__(self, session: AsyncSession, message_publisher=None): super().__init__(session) self.message_publisher = message_publisher async def _publish_events(self): for event in self._events: if self.message_publisher: await self.message_publisher.publish( topic="auth.events", message=event.to_dict() ) """ def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False): super().__init__(session, auto_commit) self.event_publisher = event_publisher async def _publish_events(self): """Publish events using the provided event publisher""" if not self._events or not self.event_publisher: return try: for event in self._events: await self.event_publisher.publish(event) logger.debug(f"Published event via publisher", event_type=event.event_type) self._events.clear() except Exception as e: logger.error("Failed to publish events", error=str(e)) # Don't raise here to avoid breaking the transaction # Events will be retried or handled by the event publisher # ===== TRANSACTION CONTEXT MANAGER ===== @asynccontextmanager async def transaction_scope(session: AsyncSession, auto_commit: bool = True): """ Simple transaction context manager for single-repository operations Usage: async with transaction_scope(session) as tx_session: user = User(name="John") tx_session.add(user) # Auto-commits on success, rolls back on exception """ try: yield session if auto_commit: await session.commit() except Exception as e: await session.rollback() logger.error("Transaction scope failed", error=str(e)) raise # ===== UTILITIES ===== class RepositoryRegistry: """Registry for commonly used repository configurations""" _registry: Dict[str, Dict[str, Any]] = {} @classmethod def register( self, name: str, repository_class: Type[Repository], model_class: Type[Model], **kwargs ): """Register a repository configuration""" self._registry[name] = { "repository_class": repository_class, "model_class": model_class, "kwargs": kwargs } logger.debug(f"Registered repository configuration", name=name) @classmethod def create_repository(self, name: str, session: AsyncSession) -> Optional[Repository]: """Create repository instance from registry""" config = self._registry.get(name) if not config: logger.warning(f"Repository configuration '{name}' not found in registry") return None return config["repository_class"]( config["model_class"], session, **config["kwargs"] ) @classmethod def list_registered(self) -> List[str]: """List all registered repository names""" return list(self._registry.keys()) # ===== FACTORY FUNCTIONS ===== def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork: """Factory function to create Unit of Work instances""" return UnitOfWork(session, **kwargs) def create_service_unit_of_work( session: AsyncSession, event_publisher=None, **kwargs ) -> ServiceUnitOfWork: """Factory function to create Service Unit of Work instances""" return ServiceUnitOfWork(session, event_publisher, **kwargs)