100 lines
3.5 KiB
Python
100 lines
3.5 KiB
Python
# services/suppliers/app/repositories/base.py
|
|
"""
|
|
Base repository class for common database operations
|
|
"""
|
|
|
|
from typing import TypeVar, Generic, List, Optional, Dict, Any
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import desc, asc, select, func
|
|
from uuid import UUID
|
|
|
|
# Type variable for the record type returned by BaseRepository's methods.
T = TypeVar("T")
|
|
|
|
|
|
class BaseRepository(Generic[T]):
    """Base repository with common async CRUD operations.

    Wraps a mapped model class and an ``AsyncSession``. All statements go
    through ``self.db``; ``create``, ``update`` and ``delete`` commit the
    session themselves and roll it back if the commit fails.

    The model is expected to expose an ``id`` attribute (used for lookups
    and as a stable sort key) and a ``tenant_id`` attribute (used by the
    tenant-scoped methods).
    """

    def __init__(self, model: type, db: AsyncSession):
        """Bind the repository to a model class and a session.

        Args:
            model: Mapped model class whose rows this repository manages.
            db: Async session used for every statement and commit.
        """
        self.model = model
        self.db = db

    async def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise.

        After a failed flush/commit the session cannot be used again until it
        is rolled back, so rolling back here keeps it usable for the caller.
        The original exception propagates unchanged.
        """
        try:
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise

    async def create(self, obj_data: Dict[str, Any]) -> T:
        """Create and commit a new record.

        Args:
            obj_data: Keyword arguments passed to the model constructor.

        Returns:
            The new instance, refreshed after the commit so values generated
            by the database (defaults, server-side columns) are loaded.
        """
        db_obj = self.model(**obj_data)
        self.db.add(db_obj)
        await self._commit()
        await self.db.refresh(db_obj)
        return db_obj

    async def get_by_id(self, record_id: UUID) -> Optional[T]:
        """Return the record whose ``id`` equals ``record_id``, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        stmt = select(self.model).where(self.model.id == record_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_tenant_id(
        self, tenant_id: UUID, limit: int = 100, offset: int = 0
    ) -> List[T]:
        """Return one page of records belonging to ``tenant_id``.

        Args:
            tenant_id: Tenant whose records are returned.
            limit: Maximum number of records to return.
            offset: Number of records to skip before the page starts.
        """
        stmt = (
            select(self.model)
            .where(self.model.tenant_id == tenant_id)
            # LIMIT/OFFSET without ORDER BY yields rows in an unspecified
            # order, so consecutive pages could overlap or skip rows.
            .order_by(self.model.id)
            .limit(limit)
            .offset(offset)
        )
        result = await self.db.execute(stmt)
        return result.scalars().all()

    async def update(
        self, record_id: UUID, update_data: Dict[str, Any]
    ) -> Optional[T]:
        """Apply ``update_data`` to a record and commit.

        Keys that are not attributes of the loaded record are ignored.

        Returns:
            The refreshed record, or ``None`` if no record has ``record_id``.
        """
        db_obj = await self.get_by_id(record_id)
        # Explicit None check: a model that defines __bool__/__len__ must not
        # be mistaken for "not found".
        if db_obj is None:
            return None
        for key, value in update_data.items():
            if hasattr(db_obj, key):
                setattr(db_obj, key, value)
        await self._commit()
        await self.db.refresh(db_obj)
        return db_obj

    async def delete(self, record_id: UUID) -> bool:
        """Delete a record and commit.

        Returns:
            ``True`` if the record existed and was deleted, ``False`` otherwise.
        """
        db_obj = await self.get_by_id(record_id)
        if db_obj is None:
            return False
        await self.db.delete(db_obj)
        await self._commit()
        return True

    async def count_by_tenant(self, tenant_id: UUID) -> int:
        """Return how many records belong to ``tenant_id``."""
        stmt = (
            select(func.count())
            .select_from(self.model)
            .where(self.model.tenant_id == tenant_id)
        )
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def list_with_filters(
        self,
        tenant_id: UUID,
        filters: Optional[Dict[str, Any]] = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        limit: int = 100,
        offset: int = 0,
    ) -> List[T]:
        """Return one page of a tenant's records, filtered and sorted.

        Must be awaited. Built with ``select()`` because ``AsyncSession`` does
        not provide the legacy ``Session.query()`` API.

        Args:
            tenant_id: Tenant whose records are listed.
            filters: Optional ``{attribute: value}`` equality filters. Keys
                that are not model attributes, and ``None`` values, are
                ignored.
            sort_by: Model attribute to sort on; ignored if the model has no
                such attribute.
            sort_order: ``"desc"`` (case-insensitive) sorts descending; any
                other value sorts ascending.
            limit: Maximum number of records to return.
            offset: Number of records to skip before the page starts.
        """
        stmt = select(self.model).where(self.model.tenant_id == tenant_id)

        for key, value in (filters or {}).items():
            if value is not None and hasattr(self.model, key):
                stmt = stmt.where(getattr(self.model, key) == value)

        if hasattr(self.model, sort_by):
            direction = desc if sort_order.lower() == "desc" else asc
            stmt = stmt.order_by(direction(getattr(self.model, sort_by)))
        # ``id`` as a final tiebreaker keeps LIMIT/OFFSET pages stable when
        # the sort column has duplicates or no sort column was applied.
        stmt = stmt.order_by(self.model.id)

        result = await self.db.execute(stmt.limit(limit).offset(offset))
        return result.scalars().all()

    async def exists(self, record_id: UUID) -> bool:
        """Return whether a record with ``record_id`` exists.

        Must be awaited (``AsyncSession`` has no ``query()`` method). Selects
        only the ``id`` column with ``LIMIT 1`` rather than loading an entity.
        """
        stmt = select(self.model.id).where(self.model.id == record_id).limit(1)
        result = await self.db.execute(stmt)
        return result.first() is not None