REFACTOR - Database logic
This commit is contained in:
422
shared/database/repository.py
Normal file
422
shared/database/repository.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Base Repository Pattern for Database Operations
|
||||
Provides generic CRUD operations, query building, and caching
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, TypeVar, Generic, Type, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from contextlib import asynccontextmanager
|
||||
import structlog
|
||||
|
||||
from .exceptions import (
|
||||
DatabaseError,
|
||||
RecordNotFoundError,
|
||||
DuplicateRecordError,
|
||||
ConstraintViolationError
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Type variables for the generic repository.
# NOTE(review): `bound=declarative_base()` creates a brand-new declarative base
# at import time, distinct from whatever Base the real models inherit from, so
# the bound is unlikely to be satisfied by actual model classes — confirm this
# should reference the project's shared Base instead.
Model = TypeVar('Model', bound=declarative_base())
# Input schema (Pydantic v1/v2 or dict) accepted by create operations.
CreateSchema = TypeVar('CreateSchema')
# Input schema (Pydantic v1/v2 or dict) accepted by update operations.
UpdateSchema = TypeVar('UpdateSchema')
|
||||
|
||||
|
||||
class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
    """
    Base repository providing generic CRUD operations

    Args:
        model: SQLAlchemy model class
        session: Database session
        cache_ttl: Cache time-to-live in seconds (optional)
    """

    def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
        # Model class this repository operates on; its __name__ is used to
        # build cache keys of the form "<ModelName>:<id>".
        self.model = model
        # Async session every query in this repository runs against.
        self.session = session
        # NOTE(review): cache_ttl is stored but no expiry is enforced anywhere
        # in this class — cached entries live until explicitly evicted. Confirm
        # whether TTL-based expiry was intended.
        self.cache_ttl = cache_ttl
        # In-process cache, enabled only for a truthy TTL (cache_ttl=0 disables it).
        self._cache = {} if cache_ttl else None
|
||||
|
||||
# ===== CORE CRUD OPERATIONS =====
|
||||
|
||||
async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
|
||||
"""Create a new record"""
|
||||
try:
|
||||
# Convert schema to dict if needed
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
obj_data = obj_in.model_dump()
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
obj_data = obj_in.dict()
|
||||
else:
|
||||
obj_data = obj_in
|
||||
|
||||
# Merge with additional kwargs
|
||||
obj_data.update(kwargs)
|
||||
|
||||
db_obj = self.model(**obj_data)
|
||||
self.session.add(db_obj)
|
||||
await self.session.flush() # Get ID without committing
|
||||
await self.session.refresh(db_obj)
|
||||
|
||||
logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
|
||||
return db_obj
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
|
||||
raise DuplicateRecordError(f"Record with provided data already exists")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to create record: {str(e)}")
|
||||
|
||||
async def get_by_id(self, record_id: Any) -> Optional[Model]:
|
||||
"""Get record by ID with optional caching"""
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
|
||||
# Check cache first
|
||||
if self._cache and cache_key in self._cache:
|
||||
logger.debug(f"Cache hit for {cache_key}")
|
||||
return self._cache[cache_key]
|
||||
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(self.model).where(self.model.id == record_id)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
# Cache the result
|
||||
if self._cache and record:
|
||||
self._cache[cache_key] = record
|
||||
|
||||
return record
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting {self.model.__name__} by ID",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get record: {str(e)}")
|
||||
|
||||
async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
|
||||
"""Get record by specific field"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(self.model).where(getattr(self.model, field_name) == value)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except AttributeError:
|
||||
raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting {self.model.__name__} by {field_name}",
|
||||
value=value, error=str(e))
|
||||
raise DatabaseError(f"Failed to get record: {str(e)}")
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
order_by: Optional[str] = None,
|
||||
order_desc: bool = False,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Model]:
|
||||
"""Get multiple records with pagination, sorting, and filtering"""
|
||||
try:
|
||||
query = select(self.model)
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
conditions = []
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field):
|
||||
if isinstance(value, list):
|
||||
conditions.append(getattr(self.model, field).in_(value))
|
||||
else:
|
||||
conditions.append(getattr(self.model, field) == value)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
# Apply ordering
|
||||
if order_by and hasattr(self.model, order_by):
|
||||
order_field = getattr(self.model, order_by)
|
||||
if order_desc:
|
||||
query = query.order_by(desc(order_field))
|
||||
else:
|
||||
query = query.order_by(asc(order_field))
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting multiple {self.model.__name__} records",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get records: {str(e)}")
|
||||
|
||||
async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
|
||||
"""Update record by ID"""
|
||||
try:
|
||||
# Convert schema to dict if needed
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
else:
|
||||
update_data = obj_in
|
||||
|
||||
# Merge with additional kwargs
|
||||
update_data.update(kwargs)
|
||||
|
||||
# Remove None values
|
||||
update_data = {k: v for k, v in update_data.items() if v is not None}
|
||||
|
||||
if not update_data:
|
||||
logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
# Perform update
|
||||
result = await self.session.execute(
|
||||
update(self.model)
|
||||
.where(self.model.id == record_id)
|
||||
.values(**update_data)
|
||||
.returning(self.model)
|
||||
)
|
||||
|
||||
updated_record = result.scalar_one_or_none()
|
||||
|
||||
if not updated_record:
|
||||
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
|
||||
|
||||
# Clear cache
|
||||
if self._cache:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
|
||||
return updated_record
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error updating {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise ConstraintViolationError(f"Update violates database constraints")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to update record: {str(e)}")
|
||||
|
||||
async def delete(self, record_id: Any) -> bool:
|
||||
"""Delete record by ID"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
delete(self.model).where(self.model.id == record_id)
|
||||
)
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
if deleted_count == 0:
|
||||
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
|
||||
|
||||
# Clear cache
|
||||
if self._cache:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error deleting {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to delete record: {str(e)}")
|
||||
|
||||
# ===== ADVANCED QUERY OPERATIONS =====
|
||||
|
||||
async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
|
||||
"""Count records with optional filters"""
|
||||
try:
|
||||
query = select(func.count(self.model.id))
|
||||
|
||||
if filters:
|
||||
conditions = []
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field):
|
||||
if isinstance(value, list):
|
||||
conditions.append(getattr(self.model, field).in_(value))
|
||||
else:
|
||||
conditions.append(getattr(self.model, field) == value)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
|
||||
raise DatabaseError(f"Failed to count records: {str(e)}")
|
||||
|
||||
async def exists(self, record_id: Any) -> bool:
|
||||
"""Check if record exists by ID"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(func.count(self.model.id)).where(self.model.id == record_id)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
return count > 0
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error checking existence of {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to check record existence: {str(e)}")
|
||||
|
||||
async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
|
||||
"""Create multiple records in bulk"""
|
||||
try:
|
||||
if not objects:
|
||||
return []
|
||||
|
||||
db_objects = []
|
||||
for obj_in in objects:
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
obj_data = obj_in.model_dump()
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
obj_data = obj_in.dict()
|
||||
else:
|
||||
obj_data = obj_in
|
||||
|
||||
db_objects.append(self.model(**obj_data))
|
||||
|
||||
self.session.add_all(db_objects)
|
||||
await self.session.flush()
|
||||
|
||||
for db_obj in db_objects:
|
||||
await self.session.refresh(db_obj)
|
||||
|
||||
logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
|
||||
return db_objects
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
|
||||
raise DuplicateRecordError(f"One or more records already exist")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to create records: {str(e)}")
|
||||
|
||||
async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
|
||||
"""Update multiple records in bulk"""
|
||||
try:
|
||||
if not updates:
|
||||
return 0
|
||||
|
||||
# Group updates by fields being updated for efficiency
|
||||
for update_data in updates:
|
||||
if 'id' not in update_data:
|
||||
raise ValueError("Each update must include 'id' field")
|
||||
|
||||
record_id = update_data.pop('id')
|
||||
await self.session.execute(
|
||||
update(self.model)
|
||||
.where(self.model.id == record_id)
|
||||
.values(**update_data)
|
||||
)
|
||||
|
||||
# Clear relevant cache entries
|
||||
if self._cache:
|
||||
for update_data in updates:
|
||||
record_id = update_data.get('id')
|
||||
if record_id:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
|
||||
return len(updates)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to update records: {str(e)}")
|
||||
|
||||
# ===== SEARCH AND QUERY BUILDING =====
|
||||
|
||||
async def search(
|
||||
self,
|
||||
search_term: str,
|
||||
search_fields: List[str],
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[Model]:
|
||||
"""Search records across multiple fields"""
|
||||
try:
|
||||
conditions = []
|
||||
for field in search_fields:
|
||||
if hasattr(self.model, field):
|
||||
field_obj = getattr(self.model, field)
|
||||
# Case-insensitive partial match
|
||||
conditions.append(field_obj.ilike(f"%{search_term}%"))
|
||||
|
||||
if not conditions:
|
||||
logger.warning(f"No valid search fields provided for {self.model.__name__}")
|
||||
return []
|
||||
|
||||
query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error searching {self.model.__name__}",
|
||||
search_term=search_term, error=str(e))
|
||||
raise DatabaseError(f"Failed to search records: {str(e)}")
|
||||
|
||||
async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Execute raw SQL query (use with caution)"""
|
||||
try:
|
||||
result = await self.session.execute(text(query), params or {})
|
||||
return result
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error executing raw query", query=query, error=str(e))
|
||||
raise DatabaseError(f"Failed to execute query: {str(e)}")
|
||||
|
||||
# ===== CACHE MANAGEMENT =====
|
||||
|
||||
def clear_cache(self, record_id: Optional[Any] = None):
|
||||
"""Clear cache for specific record or all records"""
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
if record_id:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
else:
|
||||
# Clear all cache entries for this model
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{self.model.__name__}:")]
|
||||
for key in keys_to_remove:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)
|
||||
|
||||
# ===== CONTEXT MANAGERS =====
|
||||
|
||||
    @asynccontextmanager
    async def transaction(self):
        """Context manager for explicit transaction handling.

        Yields the repository's session; commits when the ``async with`` body
        completes normally, otherwise rolls back, logs, and re-raises the
        original exception unchanged.
        """
        try:
            yield self.session
            await self.session.commit()
        except Exception as e:
            await self.session.rollback()
            logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
            # Bare re-raise preserves the caller's exception and traceback.
            raise
|
||||
Reference in New Issue
Block a user