REFACTOR - Database logic
This commit is contained in:
304
shared/database/unit_of_work.py
Normal file
304
shared/database/unit_of_work.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Unit of Work Pattern Implementation
|
||||
Manages transactions across multiple repositories with event publishing
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from abc import ABC, abstractmethod
|
||||
import structlog
|
||||
|
||||
from .repository import BaseRepository
|
||||
from .exceptions import TransactionError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
Model = TypeVar('Model')
|
||||
Repository = TypeVar('Repository', bound=BaseRepository)
|
||||
|
||||
|
||||
class BaseEvent(ABC):
|
||||
"""Base class for domain events"""
|
||||
|
||||
def __init__(self, event_type: str, data: Dict[str, Any]):
|
||||
self.event_type = event_type
|
||||
self.data = data
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert event to dictionary for publishing"""
|
||||
pass
|
||||
|
||||
|
||||
class DomainEvent(BaseEvent):
|
||||
"""Standard domain event implementation"""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"event_type": self.event_type,
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
|
||||
class UnitOfWork:
|
||||
"""
|
||||
Unit of Work pattern for managing transactions and coordinating repositories
|
||||
|
||||
Usage:
|
||||
async with UnitOfWork(session) as uow:
|
||||
user_repo = uow.register_repository("users", UserRepository, User)
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
user = await user_repo.create(user_data)
|
||||
sale = await sales_repo.create(sales_data)
|
||||
|
||||
await uow.commit()
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession, auto_commit: bool = False):
|
||||
self.session = session
|
||||
self.auto_commit = auto_commit
|
||||
self._repositories: Dict[str, BaseRepository] = {}
|
||||
self._events: List[BaseEvent] = []
|
||||
self._committed = False
|
||||
self._rolled_back = False
|
||||
|
||||
def register_repository(
|
||||
self,
|
||||
name: str,
|
||||
repository_class: Type[Repository],
|
||||
model_class: Type[Model],
|
||||
**kwargs
|
||||
) -> Repository:
|
||||
"""
|
||||
Register a repository with the unit of work
|
||||
|
||||
Args:
|
||||
name: Unique name for the repository
|
||||
repository_class: Repository class to instantiate
|
||||
model_class: SQLAlchemy model class
|
||||
**kwargs: Additional arguments for repository
|
||||
|
||||
Returns:
|
||||
Instantiated repository
|
||||
"""
|
||||
if name in self._repositories:
|
||||
logger.warning(f"Repository '{name}' already registered, returning existing instance")
|
||||
return self._repositories[name]
|
||||
|
||||
repository = repository_class(model_class, self.session, **kwargs)
|
||||
self._repositories[name] = repository
|
||||
|
||||
logger.debug(f"Registered repository", name=name, model=model_class.__name__)
|
||||
return repository
|
||||
|
||||
def get_repository(self, name: str) -> Optional[Repository]:
|
||||
"""Get registered repository by name"""
|
||||
return self._repositories.get(name)
|
||||
|
||||
def add_event(self, event: BaseEvent):
|
||||
"""Add domain event to be published after commit"""
|
||||
self._events.append(event)
|
||||
logger.debug(f"Added event", event_type=event.event_type)
|
||||
|
||||
async def commit(self):
|
||||
"""Commit the transaction and publish events"""
|
||||
if self._committed:
|
||||
logger.warning("Unit of Work already committed")
|
||||
return
|
||||
|
||||
if self._rolled_back:
|
||||
raise TransactionError("Cannot commit after rollback")
|
||||
|
||||
try:
|
||||
await self.session.commit()
|
||||
self._committed = True
|
||||
|
||||
# Publish events after successful commit
|
||||
await self._publish_events()
|
||||
|
||||
logger.debug(f"Unit of Work committed successfully",
|
||||
repositories=list(self._repositories.keys()),
|
||||
events_published=len(self._events))
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await self.rollback()
|
||||
logger.error("Failed to commit Unit of Work", error=str(e))
|
||||
raise TransactionError(f"Commit failed: {str(e)}")
|
||||
|
||||
async def rollback(self):
|
||||
"""Rollback the transaction"""
|
||||
if self._rolled_back:
|
||||
logger.warning("Unit of Work already rolled back")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.session.rollback()
|
||||
self._rolled_back = True
|
||||
self._events.clear() # Clear events on rollback
|
||||
|
||||
logger.debug(f"Unit of Work rolled back",
|
||||
repositories=list(self._repositories.keys()))
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error("Failed to rollback Unit of Work", error=str(e))
|
||||
raise TransactionError(f"Rollback failed: {str(e)}")
|
||||
|
||||
async def _publish_events(self):
|
||||
"""Publish domain events (override in subclasses for actual publishing)"""
|
||||
if not self._events:
|
||||
return
|
||||
|
||||
# Default implementation just logs events
|
||||
# Override this method in service-specific implementations
|
||||
for event in self._events:
|
||||
logger.info(f"Publishing event",
|
||||
event_type=event.event_type,
|
||||
event_data=event.to_dict())
|
||||
|
||||
# Clear events after publishing
|
||||
self._events.clear()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
if exc_type is not None:
|
||||
# Exception occurred, rollback
|
||||
await self.rollback()
|
||||
return False
|
||||
|
||||
# No exception, auto-commit if enabled
|
||||
if self.auto_commit and not self._committed:
|
||||
await self.commit()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class ServiceUnitOfWork(UnitOfWork):
|
||||
"""
|
||||
Service-specific Unit of Work with event publishing integration
|
||||
|
||||
Example usage with message publishing:
|
||||
|
||||
class AuthUnitOfWork(ServiceUnitOfWork):
|
||||
def __init__(self, session: AsyncSession, message_publisher=None):
|
||||
super().__init__(session)
|
||||
self.message_publisher = message_publisher
|
||||
|
||||
async def _publish_events(self):
|
||||
for event in self._events:
|
||||
if self.message_publisher:
|
||||
await self.message_publisher.publish(
|
||||
topic="auth.events",
|
||||
message=event.to_dict()
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
|
||||
super().__init__(session, auto_commit)
|
||||
self.event_publisher = event_publisher
|
||||
|
||||
async def _publish_events(self):
|
||||
"""Publish events using the provided event publisher"""
|
||||
if not self._events or not self.event_publisher:
|
||||
return
|
||||
|
||||
try:
|
||||
for event in self._events:
|
||||
await self.event_publisher.publish(event)
|
||||
logger.debug(f"Published event via publisher",
|
||||
event_type=event.event_type)
|
||||
|
||||
self._events.clear()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to publish events", error=str(e))
|
||||
# Don't raise here to avoid breaking the transaction
|
||||
# Events will be retried or handled by the event publisher
|
||||
|
||||
|
||||
# ===== TRANSACTION CONTEXT MANAGER =====
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
|
||||
"""
|
||||
Simple transaction context manager for single-repository operations
|
||||
|
||||
Usage:
|
||||
async with transaction_scope(session) as tx_session:
|
||||
user = User(name="John")
|
||||
tx_session.add(user)
|
||||
# Auto-commits on success, rolls back on exception
|
||||
"""
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error("Transaction scope failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
# ===== UTILITIES =====
|
||||
|
||||
class RepositoryRegistry:
|
||||
"""Registry for commonly used repository configurations"""
|
||||
|
||||
_registry: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
repository_class: Type[Repository],
|
||||
model_class: Type[Model],
|
||||
**kwargs
|
||||
):
|
||||
"""Register a repository configuration"""
|
||||
self._registry[name] = {
|
||||
"repository_class": repository_class,
|
||||
"model_class": model_class,
|
||||
"kwargs": kwargs
|
||||
}
|
||||
logger.debug(f"Registered repository configuration", name=name)
|
||||
|
||||
@classmethod
|
||||
def create_repository(self, name: str, session: AsyncSession) -> Optional[Repository]:
|
||||
"""Create repository instance from registry"""
|
||||
config = self._registry.get(name)
|
||||
if not config:
|
||||
logger.warning(f"Repository configuration '{name}' not found in registry")
|
||||
return None
|
||||
|
||||
return config["repository_class"](
|
||||
config["model_class"],
|
||||
session,
|
||||
**config["kwargs"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_registered(self) -> List[str]:
|
||||
"""List all registered repository names"""
|
||||
return list(self._registry.keys())
|
||||
|
||||
|
||||
# ===== FACTORY FUNCTIONS =====
|
||||
|
||||
def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
|
||||
"""Factory function to create Unit of Work instances"""
|
||||
return UnitOfWork(session, **kwargs)
|
||||
|
||||
|
||||
def create_service_unit_of_work(
|
||||
session: AsyncSession,
|
||||
event_publisher=None,
|
||||
**kwargs
|
||||
) -> ServiceUnitOfWork:
|
||||
"""Factory function to create Service Unit of Work instances"""
|
||||
return ServiceUnitOfWork(session, event_publisher, **kwargs)
|
||||
Reference in New Issue
Block a user