# ================================================================
# services/orders/app/repositories/base_repository.py
# ================================================================
"""
Base repository class for Orders Service
"""

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from uuid import UUID

from sqlalchemy import select, update, delete, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload
import structlog

from app.core.database import Base

logger = structlog.get_logger()

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType")
UpdateSchemaType = TypeVar("UpdateSchemaType")


class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Base repository with common CRUD operations.

    Every method receives the ``AsyncSession`` explicitly and at most calls
    ``flush()``; nothing here commits, so transaction control stays with the
    caller. When a ``tenant_id`` is supplied and the model has a ``tenant_id``
    attribute, reads are scoped to that tenant and new rows are stamped with it.
    Errors are logged and re-raised unchanged.
    """

    def __init__(self, model: Type[ModelType]):
        # The ORM model class this repository reads and writes.
        self.model = model

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _apply_tenant_filter(self, query: Any, tenant_id: Optional[UUID]) -> Any:
        """Return *query* restricted to *tenant_id* if the model is tenant-scoped."""
        if tenant_id is not None and hasattr(self.model, "tenant_id"):
            query = query.where(self.model.tenant_id == tenant_id)
        return query

    def _apply_filters(self, query: Any, filters: Optional[Dict[str, Any]]) -> Any:
        """Return *query* with one WHERE condition per recognised key in *filters*.

        List values become ``IN (...)`` conditions; any other value is an
        equality test. Keys that are not attributes of the model are skipped.
        """
        if not filters:
            return query
        for field_name, field_value in filters.items():
            if not hasattr(self.model, field_name):
                continue
            field = getattr(self.model, field_name)
            if isinstance(field_value, list):
                query = query.where(field.in_(field_value))
            else:
                query = query.where(field == field_value)
        return query

    @staticmethod
    def _to_dict(obj_in: Any, *, exclude_unset: bool = False) -> Dict[str, Any]:
        """Convert a schema object or mapping into a fresh, caller-independent dict.

        Prefers Pydantic v2 ``model_dump()`` and falls back to the v1-style
        ``dict()``. A plain mapping is shallow-copied, so later key insertions
        (tenant_id, created_by, updated_by) never leak back into the caller's
        object.
        """
        kwargs = {"exclude_unset": True} if exclude_unset else {}
        if hasattr(obj_in, "model_dump"):
            return obj_in.model_dump(**kwargs)
        if hasattr(obj_in, "dict"):
            return obj_in.dict(**kwargs)
        return dict(obj_in)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    async def get(
        self,
        db: AsyncSession,
        id: UUID,
        tenant_id: Optional[UUID] = None
    ) -> Optional[ModelType]:
        """Get a single record by primary-key ``id``.

        Args:
            db: Active async session.
            id: Value compared against the model's ``id`` column.
            tenant_id: If given and the model is tenant-scoped, the record
                must also belong to this tenant.

        Returns:
            The matching instance, or ``None`` if no row matches.

        Raises:
            Any exception from query execution, including SQLAlchemy's
            ``MultipleResultsFound`` if more than one row matches.
        """
        try:
            query = select(self.model).where(self.model.id == id)
            query = self._apply_tenant_filter(query, tenant_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Error getting record",
                         model=self.model.__name__,
                         id=str(id),
                         error=str(e))
            raise

    async def get_by_field(
        self,
        db: AsyncSession,
        field_name: str,
        field_value: Any,
        tenant_id: Optional[UUID] = None
    ) -> Optional[ModelType]:
        """Get a single record whose ``field_name`` equals ``field_value``.

        Args:
            db: Active async session.
            field_name: Attribute name on the model to compare.
            field_value: Value the attribute must equal.
            tenant_id: Optional tenant scope (see ``get``).

        Returns:
            The matching instance, or ``None`` if no row matches.

        Raises:
            AttributeError: If the model has no attribute ``field_name``.
            Any exception from query execution, including SQLAlchemy's
            ``MultipleResultsFound`` if more than one row matches.
        """
        try:
            field = getattr(self.model, field_name)
            query = select(self.model).where(field == field_value)
            query = self._apply_tenant_filter(query, tenant_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Error getting record by field",
                         model=self.model.__name__,
                         field_name=field_name,
                         field_value=str(field_value),
                         error=str(e))
            raise

    async def get_multi(
        self,
        db: AsyncSession,
        tenant_id: Optional[UUID] = None,
        skip: int = 0,
        limit: int = 100,
        filters: Optional[Dict[str, Any]] = None,
        order_by: Optional[str] = None,
        order_desc: bool = False
    ) -> List[ModelType]:
        """Get multiple records with filtering, pagination, and sorting.

        Args:
            db: Active async session.
            tenant_id: Optional tenant scope (see ``get``).
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            filters: Mapping of attribute name to value; list values match
                with ``IN``, others with ``==``. Unknown names are ignored.
            order_by: Attribute name to sort on; ignored if not on the model.
            order_desc: Sort descending when True.

        Returns:
            A list of matching instances; empty if there are no matches.
        """
        try:
            query = select(self.model)
            query = self._apply_tenant_filter(query, tenant_id)
            query = self._apply_filters(query, filters)

            if order_by and hasattr(self.model, order_by):
                order_field = getattr(self.model, order_by)
                query = query.order_by(order_field.desc() if order_desc else order_field)

            query = query.offset(skip).limit(limit)

            result = await db.execute(query)
            # scalars().all() returns a Sequence; materialise to honour the List annotation.
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Error getting multiple records",
                         model=self.model.__name__,
                         error=str(e))
            raise

    async def count(
        self,
        db: AsyncSession,
        tenant_id: Optional[UUID] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> int:
        """Count records matching the tenant scope and ``filters``.

        Uses the same tenant and filter semantics as ``get_multi``.
        """
        try:
            query = select(func.count()).select_from(self.model)
            query = self._apply_tenant_filter(query, tenant_id)
            query = self._apply_filters(query, filters)
            result = await db.execute(query)
            return result.scalar()
        except Exception as e:
            logger.error("Error counting records",
                         model=self.model.__name__,
                         error=str(e))
            raise

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    async def create(
        self,
        db: AsyncSession,
        *,
        obj_in: CreateSchemaType,
        created_by: Optional[UUID] = None,
        tenant_id: Optional[UUID] = None
    ) -> ModelType:
        """Create a new record and return it after flush + refresh.

        Args:
            db: Active async session (not committed here).
            obj_in: Pydantic-style schema or plain mapping of column values.
                A mapping passed here is copied, never modified.
            created_by: Stored in ``created_by`` if the model has that column.
            tenant_id: Stored in ``tenant_id`` if the model has that column.

        Returns:
            The new instance, refreshed so server-generated values such as
            ``id`` are populated.
        """
        try:
            obj_data = self._to_dict(obj_in)

            if tenant_id is not None and hasattr(self.model, 'tenant_id'):
                obj_data['tenant_id'] = tenant_id

            if created_by is not None and hasattr(self.model, 'created_by'):
                obj_data['created_by'] = created_by

            db_obj = self.model(**obj_data)

            # Flush (not commit) so the row gets its primary key while the
            # caller keeps control of the transaction.
            db.add(db_obj)
            await db.flush()
            await db.refresh(db_obj)

            logger.info("Record created",
                        model=self.model.__name__,
                        id=str(db_obj.id))
            return db_obj
        except Exception as e:
            logger.error("Error creating record",
                         model=self.model.__name__,
                         error=str(e))
            raise

    async def update(
        self,
        db: AsyncSession,
        *,
        db_obj: ModelType,
        obj_in: Union[UpdateSchemaType, Dict[str, Any]],
        updated_by: Optional[UUID] = None
    ) -> ModelType:
        """Apply changes from ``obj_in`` to ``db_obj``, flush, and refresh it.

        Args:
            db: Active async session (not committed here).
            db_obj: Persistent instance to modify in place.
            obj_in: Schema (only explicitly set fields are applied) or plain
                mapping (all keys applied). A mapping is copied, never modified.
            updated_by: Stored in ``updated_by`` if the model has that column.

        Returns:
            The same ``db_obj`` after refresh.
        """
        try:
            # exclude_unset so omitted schema fields don't overwrite columns with defaults.
            update_data = self._to_dict(obj_in, exclude_unset=True)

            if updated_by is not None and hasattr(self.model, 'updated_by'):
                update_data['updated_by'] = updated_by

            # Keys that are not attributes of the instance are ignored.
            for field, value in update_data.items():
                if hasattr(db_obj, field):
                    setattr(db_obj, field, value)

            await db.flush()
            await db.refresh(db_obj)

            logger.info("Record updated",
                        model=self.model.__name__,
                        id=str(db_obj.id))
            return db_obj
        except Exception as e:
            logger.error("Error updating record",
                         model=self.model.__name__,
                         id=str(db_obj.id),
                         error=str(e))
            raise

    async def delete(
        self,
        db: AsyncSession,
        *,
        id: UUID,
        tenant_id: Optional[UUID] = None
    ) -> Optional[ModelType]:
        """Delete a record by ID.

        Returns:
            The deleted instance, or ``None`` if no record matched
            ``id`` (within ``tenant_id`` when given).
        """
        try:
            db_obj = await self.get(db, id=id, tenant_id=tenant_id)
            if not db_obj:
                return None

            await db.delete(db_obj)
            await db.flush()

            logger.info("Record deleted",
                        model=self.model.__name__,
                        id=str(id))
            return db_obj
        except Exception as e:
            logger.error("Error deleting record",
                         model=self.model.__name__,
                         id=str(id),
                         error=str(e))
            raise

    async def exists(
        self,
        db: AsyncSession,
        id: UUID,
        tenant_id: Optional[UUID] = None
    ) -> bool:
        """Return True if a record with ``id`` exists (within ``tenant_id`` when given)."""
        try:
            query = select(func.count()).select_from(self.model).where(self.model.id == id)
            query = self._apply_tenant_filter(query, tenant_id)
            result = await db.execute(query)
            count = result.scalar()
            return count > 0
        except Exception as e:
            logger.error("Error checking record existence",
                         model=self.model.__name__,
                         id=str(id),
                         error=str(e))
            raise