Files
bakery-ia/shared/database/repository.py

422 lines
16 KiB
Python
Raw Normal View History

2025-08-08 09:08:41 +02:00
"""
Base Repository Pattern for Database Operations
Provides generic CRUD operations, query building, and caching
"""
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any, TypeVar, Generic, Type, Union

import structlog
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import declarative_base

from .exceptions import (
    DatabaseError,
    RecordNotFoundError,
    DuplicateRecordError,
    ConstraintViolationError,
)
logger = structlog.get_logger()

# Type variables for generic repository.
# ``Model`` is left unbounded on purpose: binding it to a freshly created
# ``declarative_base()`` would tie it to a Base that no project model inherits
# from (and type checkers do not accept a call expression as a TypeVar bound).
Model = TypeVar('Model')
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')
class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
    """
    Base repository providing generic CRUD operations.

    Write methods never commit; they only execute/flush inside ``self.session``.
    Committing is left to the caller or to the ``transaction()`` context
    manager. When SQLAlchemy raises during a write, the session is rolled back
    and the error is re-raised as one of the project's database exceptions.

    Args:
        model: SQLAlchemy model class (the methods below rely on it exposing
            an ``id`` column)
        session: Database session
        cache_ttl: Cache time-to-live in seconds (optional). When truthy,
            ``get_by_id`` results are kept in a per-instance in-memory dict and
            reused until they are older than ``cache_ttl`` seconds. When
            ``None``/0, caching is disabled.
    """

    # Escape character used when turning a search term into a literal LIKE pattern.
    _LIKE_ESCAPE = '/'

    def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
        self.model = model
        self.session = session
        self.cache_ttl = cache_ttl
        # Maps "<ModelName>:<id>" -> (expiry timestamp from time.monotonic(), record).
        # ``None`` means caching is disabled; an empty dict means enabled but empty,
        # so the cache helpers test ``is not None`` rather than truthiness.
        self._cache: Optional[Dict[str, Any]] = {} if cache_ttl else None

    # ===== INTERNAL HELPERS =====

    def _cache_key(self, record_id: Any) -> str:
        """Build the cache key for a record of this repository's model."""
        return f"{self.model.__name__}:{record_id}"

    def _cache_get(self, record_id: Any) -> Optional[Model]:
        """Return the cached, unexpired record for ``record_id``, or ``None``."""
        if self._cache is None:
            return None
        key = self._cache_key(record_id)
        entry = self._cache.get(key)
        if entry is None:
            return None
        expires_at, record = entry
        if time.monotonic() >= expires_at:
            # Older than cache_ttl: evict and report a miss.
            self._cache.pop(key, None)
            return None
        return record

    def _cache_set(self, record_id: Any, record: Model) -> None:
        """Store ``record`` under ``record_id`` for ``cache_ttl`` seconds (if caching is on)."""
        if self._cache is None:
            return
        expires_at = time.monotonic() + (self.cache_ttl or 0)
        self._cache[self._cache_key(record_id)] = (expires_at, record)

    def _cache_invalidate(self, record_id: Any) -> None:
        """Drop the cache entry for ``record_id`` (no-op when caching is off)."""
        if self._cache is not None:
            self._cache.pop(self._cache_key(record_id), None)

    async def _rollback(self) -> None:
        """Roll back the session and drop cached records, which may now be stale."""
        await self.session.rollback()
        self.clear_cache()

    @staticmethod
    def _schema_to_dict(obj_in: Any, exclude_unset: bool = False) -> Dict[str, Any]:
        """Convert a Pydantic v2/v1 schema or a mapping into a new dict.

        A fresh dict is always returned, so callers' mappings are never mutated.
        ``exclude_unset`` is forwarded to ``model_dump``/``dict`` only when True.
        """
        dump_kwargs = {'exclude_unset': True} if exclude_unset else {}
        if hasattr(obj_in, 'model_dump'):
            return obj_in.model_dump(**dump_kwargs)
        if hasattr(obj_in, 'dict'):
            return obj_in.dict(**dump_kwargs)
        return dict(obj_in)

    def _build_filter_conditions(self, filters: Optional[Dict[str, Any]]) -> List[Any]:
        """Translate ``{field: value}`` filters into SQLAlchemy conditions.

        List values become ``IN`` clauses; other values become equality checks.
        Fields the model does not define are silently skipped.
        """
        conditions = []
        for field, value in (filters or {}).items():
            if not hasattr(self.model, field):
                continue
            column = getattr(self.model, field)
            if isinstance(value, list):
                conditions.append(column.in_(value))
            else:
                conditions.append(column == value)
        return conditions

    @staticmethod
    def _escape_like(term: Any) -> str:
        """Escape LIKE wildcards (``%``, ``_``) and the escape char so ``term`` matches literally."""
        esc = BaseRepository._LIKE_ESCAPE
        return (
            str(term)
            .replace(esc, esc + esc)  # must run first so added escapes aren't doubled
            .replace('%', esc + '%')
            .replace('_', esc + '_')
        )

    # ===== CORE CRUD OPERATIONS =====

    async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
        """Create a new record.

        Args:
            obj_in: Pydantic schema (``model_dump``/``dict``) or a mapping of
                column values.
            **kwargs: Extra column values merged over ``obj_in`` (kwargs win).

        Returns:
            The new model instance, flushed (not committed) and refreshed.

        Raises:
            DuplicateRecordError: on any ``IntegrityError`` during the flush.
            DatabaseError: on any other SQLAlchemy error.
        """
        try:
            obj_data = self._schema_to_dict(obj_in)
            obj_data.update(kwargs)
            db_obj = self.model(**obj_data)
            self.session.add(db_obj)
            await self.session.flush()  # Get ID without committing
            await self.session.refresh(db_obj)
            logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
            return db_obj
        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
            raise DuplicateRecordError("Record with provided data already exists") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error creating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to create record: {str(e)}") from e

    async def get_by_id(self, record_id: Any) -> Optional[Model]:
        """Get a record by ID, serving it from the TTL cache when enabled.

        Returns ``None`` when no row matches (misses are not cached).

        Raises:
            DatabaseError: on any SQLAlchemy error.
        """
        cached = self._cache_get(record_id)
        if cached is not None:
            logger.debug(f"Cache hit for {self._cache_key(record_id)}")
            return cached
        try:
            result = await self.session.execute(
                select(self.model).where(self.model.id == record_id)
            )
            record = result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by ID",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}") from e
        if record is not None:
            self._cache_set(record_id, record)
        return record

    async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
        """Get the single record whose ``field_name`` equals ``value``.

        Raises:
            ValueError: if the model has no attribute ``field_name``.
            DatabaseError: on any SQLAlchemy error (including more than one
                matching row, raised by ``scalar_one_or_none``).
        """
        # Resolve the column before touching the DB so only a missing field
        # (not an unrelated AttributeError) becomes a ValueError.
        try:
            column = getattr(self.model, field_name)
        except AttributeError:
            raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}") from None
        try:
            result = await self.session.execute(
                select(self.model).where(column == value)
            )
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by {field_name}",
                         value=value, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}") from e

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        order_by: Optional[str] = None,
        order_desc: bool = False,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Model]:
        """Get multiple records with pagination, sorting, and filtering.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            order_by: Attribute to sort by; ignored if the model lacks it.
            order_desc: Sort descending when True.
            filters: ``{field: value}`` filters; list values become ``IN``
                clauses. Unknown fields are ignored.

        Raises:
            DatabaseError: on any SQLAlchemy error.
        """
        try:
            query = select(self.model)
            conditions = self._build_filter_conditions(filters)
            if conditions:
                query = query.where(and_(*conditions))
            if order_by and hasattr(self.model, order_by):
                order_field = getattr(self.model, order_by)
                query = query.order_by(desc(order_field) if order_desc else asc(order_field))
            query = query.offset(skip).limit(limit)
            result = await self.session.execute(query)
            return result.scalars().all()
        except SQLAlchemyError as e:
            logger.error(f"Database error getting multiple {self.model.__name__} records",
                         error=str(e))
            raise DatabaseError(f"Failed to get records: {str(e)}") from e

    async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
        """Update a record by ID.

        Only fields explicitly set on a schema (``exclude_unset=True``) plus
        ``kwargs`` are written. Values that are ``None`` are dropped, so this
        method cannot set a column to NULL. If nothing remains to write, the
        current record is returned via ``get_by_id`` (``None`` if missing).

        Raises:
            RecordNotFoundError: if no row with ``record_id`` was updated.
            ConstraintViolationError: on an ``IntegrityError``.
            DatabaseError: on any other SQLAlchemy error.
        """
        try:
            update_data = self._schema_to_dict(obj_in, exclude_unset=True)
            update_data.update(kwargs)
            # Remove None values
            update_data = {k: v for k, v in update_data.items() if v is not None}
            if not update_data:
                logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
                return await self.get_by_id(record_id)
            result = await self.session.execute(
                update(self.model)
                .where(self.model.id == record_id)
                .values(**update_data)
                .returning(self.model)
            )
            updated_record = result.scalar_one_or_none()
            if updated_record is None:
                raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
            self._cache_invalidate(record_id)
            logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
            return updated_record
        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise ConstraintViolationError("Update violates database constraints") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to update record: {str(e)}") from e

    async def delete(self, record_id: Any) -> bool:
        """Delete a record by ID and return ``True``.

        Raises:
            RecordNotFoundError: if no row was deleted.
            DatabaseError: on any SQLAlchemy error.
        """
        try:
            result = await self.session.execute(
                delete(self.model).where(self.model.id == record_id)
            )
            if result.rowcount == 0:
                raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
            self._cache_invalidate(record_id)
            logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
            return True
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error deleting {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to delete record: {str(e)}") from e

    # ===== ADVANCED QUERY OPERATIONS =====

    async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
        """Count records matching ``filters`` (same semantics as ``get_multi``)."""
        try:
            query = select(func.count(self.model.id))
            conditions = self._build_filter_conditions(filters)
            if conditions:
                query = query.where(and_(*conditions))
            result = await self.session.execute(query)
            return result.scalar() or 0
        except SQLAlchemyError as e:
            logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
            raise DatabaseError(f"Failed to count records: {str(e)}") from e

    async def exists(self, record_id: Any) -> bool:
        """Return whether a record with ``record_id`` exists (bypasses the cache)."""
        try:
            result = await self.session.execute(
                select(func.count(self.model.id)).where(self.model.id == record_id)
            )
            return (result.scalar() or 0) > 0
        except SQLAlchemyError as e:
            logger.error(f"Database error checking existence of {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to check record existence: {str(e)}") from e

    async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
        """Create multiple records in one flush and return them refreshed.

        Raises:
            DuplicateRecordError: on an ``IntegrityError``.
            DatabaseError: on any other SQLAlchemy error.
        """
        try:
            if not objects:
                return []
            db_objects = [self.model(**self._schema_to_dict(obj_in)) for obj_in in objects]
            self.session.add_all(db_objects)
            await self.session.flush()
            # One refresh per object: loads server-generated values (ids, defaults).
            for db_obj in db_objects:
                await self.session.refresh(db_obj)
            logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
            return db_objects
        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
            raise DuplicateRecordError("One or more records already exist") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to create records: {str(e)}") from e

    async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
        """Update multiple records, one UPDATE per entry, and return ``len(updates)``.

        Each entry is a dict containing ``'id'`` plus the column values to set.
        The caller's dicts are not modified.

        Raises:
            ValueError: if any entry lacks ``'id'`` (checked before any UPDATE runs).
            DatabaseError: on any SQLAlchemy error.
        """
        if not updates:
            return 0
        # Validate everything first so a bad entry cannot leave earlier
        # UPDATEs applied in the session.
        for update_data in updates:
            if 'id' not in update_data:
                raise ValueError("Each update must include 'id' field")
        try:
            for update_data in updates:
                values = {k: v for k, v in update_data.items() if k != 'id'}
                await self.session.execute(
                    update(self.model)
                    .where(self.model.id == update_data['id'])
                    .values(**values)
                )
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to update records: {str(e)}") from e
        for update_data in updates:
            self._cache_invalidate(update_data['id'])
        logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
        return len(updates)

    # ===== SEARCH AND QUERY BUILDING =====

    async def search(
        self,
        search_term: str,
        search_fields: List[str],
        skip: int = 0,
        limit: int = 100
    ) -> List[Model]:
        """Return records where any of ``search_fields`` contains ``search_term``.

        Matching is a case-insensitive substring match (``ILIKE``). LIKE
        wildcards in ``search_term`` are escaped so it is matched literally.
        Fields the model lacks are skipped; if none remain, ``[]`` is returned.

        Raises:
            DatabaseError: on any SQLAlchemy error.
        """
        try:
            pattern = f"%{self._escape_like(search_term)}%"
            conditions = [
                getattr(self.model, field).ilike(pattern, escape=self._LIKE_ESCAPE)
                for field in search_fields
                if hasattr(self.model, field)
            ]
            if not conditions:
                logger.warning(f"No valid search fields provided for {self.model.__name__}")
                return []
            query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
            result = await self.session.execute(query)
            return result.scalars().all()
        except SQLAlchemyError as e:
            logger.error(f"Database error searching {self.model.__name__}",
                         search_term=search_term, error=str(e))
            raise DatabaseError(f"Failed to search records: {str(e)}") from e

    async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Execute raw SQL (use with caution) with bound ``params``; returns the raw result.

        Values must be passed through ``params`` (``:name`` placeholders), never
        interpolated into ``query``, to stay safe from SQL injection.
        """
        try:
            return await self.session.execute(text(query), params or {})
        except SQLAlchemyError as e:
            logger.error("Database error executing raw query", query=query, error=str(e))
            raise DatabaseError(f"Failed to execute query: {str(e)}") from e

    # ===== CACHE MANAGEMENT =====

    def clear_cache(self, record_id: Optional[Any] = None):
        """Clear the cache entry for ``record_id``, or all of this model's entries when ``None``."""
        if not self._cache:
            return  # caching disabled, or nothing cached
        if record_id is not None:
            self._cache.pop(self._cache_key(record_id), None)
        else:
            prefix = f"{self.model.__name__}:"
            for key in [k for k in self._cache if k.startswith(prefix)]:
                self._cache.pop(key, None)
        logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)

    # ===== CONTEXT MANAGERS =====

    @asynccontextmanager
    async def transaction(self):
        """Context manager for explicit transaction handling.

        Yields the session and commits when the ``async with`` body finishes.
        If the body or the commit raises, the session is rolled back, cached
        records are dropped, and the original exception is re-raised.
        """
        try:
            yield self.session
            await self.session.commit()
        except Exception as e:
            await self._rollback()
            logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
            raise