"""
Base Repository Pattern for Database Operations

Provides generic CRUD operations, query building, and caching.
"""

import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

import structlog
from sqlalchemy import and_, asc, delete, desc, func, or_, select, text, update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import declarative_base

from .exceptions import (
    ConstraintViolationError,
    DatabaseError,
    DuplicateRecordError,
    RecordNotFoundError,
)

logger = structlog.get_logger()

# Type variables for generic repository
Model = TypeVar('Model', bound=declarative_base())
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')


class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
    """
    Base repository providing generic CRUD operations.

    The model is expected to expose an ``id`` column; every by-ID operation
    filters on ``self.model.id``.

    Caching is enabled only when ``cache_ttl`` is truthy. Cached values are the
    ORM instances returned by ``get_by_id``; each entry expires ``cache_ttl``
    seconds after it was stored. The cache is emptied whenever this repository
    rolls the session back, and when ``transaction()`` commits. Code that
    commits the session directly should call ``clear_cache()`` afterwards
    unless the session was created with ``expire_on_commit=False``.

    Args:
        model: SQLAlchemy model class
        session: Database session
        cache_ttl: Cache time-to-live in seconds (optional)
    """

    def __init__(self, model: Type[Model], session: AsyncSession,
                 cache_ttl: Optional[int] = None):
        self.model = model
        self.session = session
        self.cache_ttl = cache_ttl
        # Maps cache key -> (monotonic expiry deadline, record).
        # ``None`` means caching is disabled; an empty dict means "enabled, empty".
        self._cache: Optional[Dict[str, Tuple[float, Any]]] = {} if cache_ttl else None

    # ===== INTERNAL HELPERS =====

    @staticmethod
    def _schema_to_dict(obj_in: Any, exclude_unset: bool = False) -> Dict[str, Any]:
        """Return a fresh dict built from a schema object or a mapping.

        Objects exposing ``model_dump`` or ``dict`` are dumped through that
        method (with ``exclude_unset=True`` when requested); anything else is
        passed to ``dict()``. The result is always a copy so callers' input is
        never mutated.
        """
        dump_kwargs = {"exclude_unset": True} if exclude_unset else {}
        if hasattr(obj_in, 'model_dump'):
            return dict(obj_in.model_dump(**dump_kwargs))
        if hasattr(obj_in, 'dict'):
            return dict(obj_in.dict(**dump_kwargs))
        return dict(obj_in)

    def _cache_key(self, record_id: Any) -> str:
        """Build the cache key used for a record of this repository's model."""
        return f"{self.model.__name__}:{record_id}"

    def _cache_get(self, record_id: Any) -> Optional[Model]:
        """Return the cached record for ``record_id`` or ``None`` if absent/expired."""
        if self._cache is None:
            return None
        key = self._cache_key(record_id)
        entry = self._cache.get(key)
        if entry is None:
            return None
        expires_at, record = entry
        if time.monotonic() >= expires_at:
            # Entry outlived its TTL: drop it and report a miss.
            self._cache.pop(key, None)
            return None
        return record

    def _cache_put(self, record_id: Any, record: Model) -> None:
        """Store ``record`` under ``record_id`` with a deadline ``cache_ttl`` seconds away."""
        if self._cache is None or not self.cache_ttl:
            return
        self._cache[self._cache_key(record_id)] = (
            time.monotonic() + self.cache_ttl,
            record,
        )

    def _cache_evict(self, record_id: Any) -> None:
        """Remove the cache entry for ``record_id`` if caching is enabled."""
        if self._cache is not None:
            self._cache.pop(self._cache_key(record_id), None)

    async def _rollback(self) -> None:
        """Roll the session back and drop every cached record.

        After a rollback the cached instances may describe state that no
        longer exists in the database, so the whole cache is discarded.
        """
        await self.session.rollback()
        if self._cache is not None:
            self._cache.clear()

    def _build_filter_conditions(self, filters: Optional[Dict[str, Any]]) -> List[Any]:
        """Translate a ``{field: value}`` mapping into SQLAlchemy conditions.

        List values become ``IN`` clauses, other values become equality
        checks. Keys that are not attributes of the model are ignored.
        """
        conditions: List[Any] = []
        if not filters:
            return conditions
        for field, value in filters.items():
            if not hasattr(self.model, field):
                continue
            column = getattr(self.model, field)
            if isinstance(value, list):
                conditions.append(column.in_(value))
            else:
                conditions.append(column == value)
        return conditions

    # ===== CORE CRUD OPERATIONS =====

    async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
        """Create a new record.

        Args:
            obj_in: Schema object (``model_dump``/``dict``) or mapping of column values.
            **kwargs: Extra column values merged over ``obj_in``.

        Returns:
            The flushed and refreshed model instance (not committed).

        Raises:
            DuplicateRecordError: On any ``IntegrityError`` from the flush.
            DatabaseError: On any other ``SQLAlchemyError``.
        """
        # Copy so that merging kwargs never mutates the caller's dict.
        obj_data = self._schema_to_dict(obj_in)
        obj_data.update(kwargs)

        try:
            db_obj = self.model(**obj_data)
            self.session.add(db_obj)
            await self.session.flush()  # Get ID without committing
            await self.session.refresh(db_obj)

            logger.debug(f"Created {self.model.__name__}",
                         record_id=getattr(db_obj, 'id', None))
            return db_obj

        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
            raise DuplicateRecordError("Record with provided data already exists") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error creating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to create record: {str(e)}") from e

    async def get_by_id(self, record_id: Any) -> Optional[Model]:
        """Get record by ID with optional caching.

        Returns:
            The matching record, or ``None`` when no row has that ID.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        cached = self._cache_get(record_id)
        if cached is not None:
            logger.debug(f"Cache hit for {self._cache_key(record_id)}")
            return cached

        try:
            result = await self.session.execute(
                select(self.model).where(self.model.id == record_id)
            )
            record = result.scalar_one_or_none()

            # Only found records are cached; misses always go to the database.
            if record is not None:
                self._cache_put(record_id, record)

            return record

        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by ID",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}") from e

    async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
        """Get record by specific field.

        Raises:
            ValueError: If the model has no attribute named ``field_name``.
            DatabaseError: On any ``SQLAlchemyError``.
        """
        # Resolve the column outside the try block so that only a missing
        # model attribute (not an AttributeError raised during execution)
        # is reported as an unknown field.
        try:
            column = getattr(self.model, field_name)
        except AttributeError:
            raise ValueError(
                f"Field '{field_name}' not found in {self.model.__name__}"
            ) from None

        try:
            result = await self.session.execute(
                select(self.model).where(column == value)
            )
            return result.scalar_one_or_none()

        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by {field_name}",
                         value=value, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}") from e

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        order_by: Optional[str] = None,
        order_desc: bool = False,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Model]:
        """Get multiple records with pagination, sorting, and filtering.

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            order_by: Model attribute name to sort by; ignored if the model
                has no such attribute.
            order_desc: Sort descending when true, ascending otherwise.
            filters: ``{field: value}`` mapping; list values become ``IN``
                clauses, unknown fields are ignored.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            query = select(self.model)

            # Apply filters
            conditions = self._build_filter_conditions(filters)
            if conditions:
                query = query.where(and_(*conditions))

            # Apply ordering
            if order_by and hasattr(self.model, order_by):
                order_field = getattr(self.model, order_by)
                direction = desc if order_desc else asc
                query = query.order_by(direction(order_field))

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return list(result.scalars().all())

        except SQLAlchemyError as e:
            logger.error(f"Database error getting multiple {self.model.__name__} records",
                         error=str(e))
            raise DatabaseError(f"Failed to get records: {str(e)}") from e

    async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
        """Update record by ID.

        Values that are ``None`` after merging ``obj_in`` and ``kwargs`` are
        dropped, so this method cannot be used to set a column to NULL. When
        nothing remains to update, the current record is returned unchanged.

        Raises:
            RecordNotFoundError: If no row has the given ID.
            ConstraintViolationError: On any ``IntegrityError``.
            DatabaseError: On any other ``SQLAlchemyError``.
        """
        # Copy so that merging kwargs never mutates the caller's dict.
        update_data = self._schema_to_dict(obj_in, exclude_unset=True)
        update_data.update(kwargs)

        # Remove None values
        update_data = {k: v for k, v in update_data.items() if v is not None}

        if not update_data:
            logger.warning(f"No data to update for {self.model.__name__}",
                           record_id=record_id)
            return await self.get_by_id(record_id)

        try:
            # Perform update
            result = await self.session.execute(
                update(self.model)
                .where(self.model.id == record_id)
                .values(**update_data)
                .returning(self.model)
            )

            updated_record = result.scalar_one_or_none()

            if not updated_record:
                raise RecordNotFoundError(
                    f"{self.model.__name__} with id {record_id} not found"
                )

            # Clear cache
            self._cache_evict(record_id)

            logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
            return updated_record

        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise ConstraintViolationError("Update violates database constraints") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to update record: {str(e)}") from e

    async def delete(self, record_id: Any) -> bool:
        """Delete record by ID.

        Returns:
            ``True`` when a row was deleted.

        Raises:
            RecordNotFoundError: If no row matched the given ID.
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            result = await self.session.execute(
                delete(self.model).where(self.model.id == record_id)
            )

            if result.rowcount == 0:
                raise RecordNotFoundError(
                    f"{self.model.__name__} with id {record_id} not found"
                )

            # Clear cache
            self._cache_evict(record_id)

            logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
            return True

        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error deleting {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to delete record: {str(e)}") from e

    # ===== ADVANCED QUERY OPERATIONS =====

    async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
        """Count records with optional filters.

        ``filters`` follows the same rules as in ``get_multi``.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            query = select(func.count(self.model.id))

            conditions = self._build_filter_conditions(filters)
            if conditions:
                query = query.where(and_(*conditions))

            result = await self.session.execute(query)
            return result.scalar() or 0

        except SQLAlchemyError as e:
            logger.error(f"Database error counting {self.model.__name__} records",
                         error=str(e))
            raise DatabaseError(f"Failed to count records: {str(e)}") from e

    async def exists(self, record_id: Any) -> bool:
        """Check if record exists by ID.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            result = await self.session.execute(
                select(func.count(self.model.id)).where(self.model.id == record_id)
            )
            return (result.scalar() or 0) > 0

        except SQLAlchemyError as e:
            logger.error(f"Database error checking existence of {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to check record existence: {str(e)}") from e

    async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
        """Create multiple records in bulk.

        All instances are added and flushed together. Instances are refreshed
        individually only when there are at most 100 of them; larger batches
        are returned as flushed, without a refresh.

        Raises:
            DuplicateRecordError: On any ``IntegrityError`` from the flush.
            DatabaseError: On any other ``SQLAlchemyError``.
        """
        if not objects:
            return []

        try:
            db_objects = [
                self.model(**self._schema_to_dict(obj_in)) for obj_in in objects
            ]

            self.session.add_all(db_objects)
            await self.session.flush()

            # Skip expensive individual refresh operations for large datasets
            # Only refresh if we have a small number of objects
            if len(db_objects) <= 100:
                for db_obj in db_objects:
                    await self.session.refresh(db_obj)
            else:
                # For large datasets, just log without refresh to prevent memory issues
                logger.debug(
                    f"Skipped individual refresh for large bulk operation "
                    f"({len(db_objects)} records)"
                )

            logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
            return db_objects

        except IntegrityError as e:
            await self._rollback()
            logger.error(f"Integrity error bulk creating {self.model.__name__}",
                         error=str(e))
            raise DuplicateRecordError("One or more records already exist") from e
        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error bulk creating {self.model.__name__}",
                         error=str(e))
            raise DatabaseError(f"Failed to create records: {str(e)}") from e

    async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
        """Update multiple records in bulk.

        Each item must contain an ``'id'`` key; the remaining keys are the
        column values to set. Input dicts are not modified. Every item is
        validated before any statement is executed, so a malformed item never
        leaves a partially applied batch behind. Items carrying only ``'id'``
        are skipped as no-ops.

        Returns:
            The number of items in ``updates``.

        Raises:
            ValueError: If any item lacks the ``'id'`` key.
            DatabaseError: On any ``SQLAlchemyError``.
        """
        if not updates:
            return 0

        # Validate and split id from values up front, without mutating input.
        prepared: List[Tuple[Any, Dict[str, Any]]] = []
        for update_data in updates:
            if 'id' not in update_data:
                raise ValueError("Each update must include 'id' field")
            values = {k: v for k, v in update_data.items() if k != 'id'}
            prepared.append((update_data['id'], values))

        try:
            for record_id, values in prepared:
                if not values:
                    # Nothing to set for this record.
                    continue
                await self.session.execute(
                    update(self.model)
                    .where(self.model.id == record_id)
                    .values(**values)
                )

            # Clear relevant cache entries
            for record_id, _ in prepared:
                self._cache_evict(record_id)

            logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
            return len(updates)

        except SQLAlchemyError as e:
            await self._rollback()
            logger.error(f"Database error bulk updating {self.model.__name__}",
                         error=str(e))
            raise DatabaseError(f"Failed to update records: {str(e)}") from e

    # ===== SEARCH AND QUERY BUILDING =====

    async def search(
        self,
        search_term: str,
        search_fields: List[str],
        skip: int = 0,
        limit: int = 100
    ) -> List[Model]:
        """Search records across multiple fields.

        Builds a case-insensitive ``ILIKE '%term%'`` condition for every name
        in ``search_fields`` that is an attribute of the model and ORs them
        together. ``search_term`` is embedded as-is, so ``%`` and ``_`` in it
        act as wildcards. Returns an empty list when no field is valid.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            pattern = f"%{search_term}%"
            # Case-insensitive partial match
            conditions = [
                getattr(self.model, field).ilike(pattern)
                for field in search_fields
                if hasattr(self.model, field)
            ]

            if not conditions:
                logger.warning(f"No valid search fields provided for {self.model.__name__}")
                return []

            query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
            result = await self.session.execute(query)
            return list(result.scalars().all())

        except SQLAlchemyError as e:
            logger.error(f"Database error searching {self.model.__name__}",
                         search_term=search_term, error=str(e))
            raise DatabaseError(f"Failed to search records: {str(e)}") from e

    async def execute_raw_query(self, query: str,
                                params: Optional[Dict[str, Any]] = None) -> Any:
        """Execute raw SQL query (use with caution).

        ``query`` is wrapped in ``text()`` and executed with ``params`` as
        bound parameters; pass values through ``params`` rather than
        formatting them into the SQL string.

        Returns:
            The result object returned by ``session.execute``.

        Raises:
            DatabaseError: On any ``SQLAlchemyError``.
        """
        try:
            return await self.session.execute(text(query), params or {})
        except SQLAlchemyError as e:
            logger.error("Database error executing raw query", query=query, error=str(e))
            raise DatabaseError(f"Failed to execute query: {str(e)}") from e

    # ===== CACHE MANAGEMENT =====

    def clear_cache(self, record_id: Optional[Any] = None):
        """Clear cache for specific record or all records.

        Args:
            record_id: ID whose entry should be dropped. ``None`` (the default)
                drops every entry belonging to this repository's model. Falsy
                IDs such as ``0`` are treated as real IDs.
        """
        if self._cache is None:
            return

        if record_id is not None:
            self._cache_evict(record_id)
        else:
            # Clear all cache entries for this model
            prefix = f"{self.model.__name__}:"
            for key in [k for k in self._cache if k.startswith(prefix)]:
                self._cache.pop(key, None)

        logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)

    # ===== CONTEXT MANAGERS =====

    @asynccontextmanager
    async def transaction(self):
        """Context manager for explicit transaction handling.

        Yields the session, commits when the body finishes normally, and rolls
        back and re-raises when the body (or the commit) raises. The cache is
        emptied in both cases: a commit may expire the cached instances, and a
        rollback invalidates them.
        """
        try:
            yield self.session
            await self.session.commit()
            if self._cache is not None:
                self._cache.clear()
        except Exception as e:
            await self._rollback()
            logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
            raise