167 lines
6.1 KiB
Python
167 lines
6.1 KiB
Python
"""
Base Repository for Data Service
Service-specific repository base class with data service utilities
"""
|
|
|
|
from typing import Optional, List, Dict, Any, Type, TypeVar, Generic
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from datetime import datetime, timezone
|
|
import structlog
|
|
|
|
from shared.database.repository import BaseRepository
|
|
from shared.database.exceptions import DatabaseError, ValidationError
|
|
|
|
# Module-level structured logger used by the repository methods below.
logger = structlog.get_logger()

# Type parameters of the data service repository; they are forwarded, in this
# order, to ``BaseRepository[Model, CreateSchema, UpdateSchema]``.
Model = TypeVar("Model")
CreateSchema = TypeVar("CreateSchema")
UpdateSchema = TypeVar("UpdateSchema")
|
|
|
|
|
|
class DataBaseRepository(
    BaseRepository[Model, CreateSchema, UpdateSchema],
    Generic[Model, CreateSchema, UpdateSchema],
):
    """Base repository for data service with common data operations.

    Extends the shared ``BaseRepository`` with tenant-scoped helpers
    (listing, counting, access validation, statistics). Every query is
    delegated to methods inherited from ``BaseRepository`` (``get_multi``,
    ``count``, ``get_by_id``).
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Forward the model class, session and cache TTL to ``BaseRepository``.

        Args:
            model: Model class this repository operates on.
            session: Async SQLAlchemy session handed to the base class.
            cache_ttl: Cache time-to-live passed through unchanged
                (presumably seconds -- confirm against ``BaseRepository``).
        """
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records filtered by tenant_id.

        Args:
            tenant_id: Tenant whose records are requested.
            skip: Offset forwarded to ``get_multi``.
            limit: Page size forwarded to ``get_multi``.

        Returns:
            Whatever ``get_multi`` returns for the ``tenant_id`` filter.
        """
        return await self.get_multi(
            skip=skip,
            limit=limit,
            filters={"tenant_id": tenant_id}
        )

    async def get_by_date_range(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records filtered by tenant, ordered by ``date`` descending.

        NOTE: the ``start_date`` / ``end_date`` bounds are validated but NOT
        yet applied to the query; the result is only filtered by tenant. A
        warning is logged whenever a bound is supplied so the gap is visible.

        Args:
            tenant_id: Tenant whose records are requested.
            start_date: Requested lower bound (currently not applied).
            end_date: Requested upper bound (currently not applied).
            skip: Offset forwarded to ``get_multi``.
            limit: Page size forwarded to ``get_multi``.

        Returns:
            Whatever ``get_multi`` returns for the tenant filter.

        Raises:
            ValidationError: A bound was given but the model has no ``date``
                attribute.
            DatabaseError: Any other failure while running the query.
        """
        try:
            if start_date or end_date:
                if not hasattr(self.model, 'date'):
                    raise ValidationError("Model does not have 'date' field for date filtering")

                # TODO: build a custom query that applies the bounds; the
                # equality-style ``filters`` dict cannot express a range.
                logger.warning("Date range bounds are not applied; returning tenant-filtered records only",
                               tenant_id=tenant_id,
                               start_date=start_date,
                               end_date=end_date)

            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="date",
                order_desc=True
            )

        except ValidationError:
            # Validation problems are the caller's to handle; do not disguise
            # them as database failures.
            raise
        except Exception as e:
            logger.error("Failed to get records by date range",
                         tenant_id=tenant_id,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}") from e

    async def count_by_tenant(self, tenant_id: str) -> int:
        """Count records for a specific tenant via the inherited ``count``."""
        return await self.count(filters={"tenant_id": tenant_id})

    async def validate_tenant_access(self, tenant_id: str, record_id: Any) -> bool:
        """Validate that a record belongs to the specified tenant.

        Args:
            tenant_id: Tenant requesting access.
            record_id: Identifier passed to ``get_by_id``.

        Returns:
            ``False`` when the record is missing, belongs to another tenant,
            or the lookup raises; ``True`` when the tenant ids match (compared
            as strings) or the record has no ``tenant_id`` attribute.
        """
        try:
            record = await self.get_by_id(record_id)
            if not record:
                return False

            # Compare as strings so UUID objects and plain strings match.
            if hasattr(record, 'tenant_id'):
                return str(record.tenant_id) == str(tenant_id)

            # NOTE(review): records without a tenant_id field are treated as
            # accessible to every tenant -- confirm this is intended.
            return True

        except Exception as e:
            logger.error("Failed to validate tenant access",
                         tenant_id=tenant_id,
                         record_id=record_id,
                         error=str(e))
            return False

    async def get_tenant_stats(self, tenant_id: str) -> Dict[str, Any]:
        """Get statistics for a specific tenant.

        Returns:
            Dict with ``tenant_id``, ``total_records``, ``recent_records``
            (always 0 for now) and ``model_type``. On failure the counts are
            0 and an ``error`` key carries the exception text; nothing is
            raised.
        """
        try:
            total_records = await self.count_by_tenant(tenant_id)

            # TODO: count recent activity for models with ``created_at``;
            # needs a custom date-filtered query, so report 0 until then.
            recent_records = 0

            return {
                "tenant_id": tenant_id,
                "total_records": total_records,
                "recent_records": recent_records,
                "model_type": self.model.__name__
            }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         tenant_id=tenant_id, error=str(e))
            return {
                "tenant_id": tenant_id,
                "total_records": 0,
                "recent_records": 0,
                "model_type": self.model.__name__,
                "error": str(e)
            }

    async def cleanup_old_records(
        self,
        tenant_id: str,
        days_old: int = 365,
        batch_size: int = 1000
    ) -> int:
        """Clean up old records for a tenant (if model has date/created_at field).

        Not implemented yet: no rows are deleted and 0 is always returned.
        ``days_old`` and ``batch_size`` are accepted for the future
        implementation and are currently unused apart from error logging.

        Returns:
            Number of records removed (currently always 0).

        Raises:
            DatabaseError: If anything in the method raises.
        """
        try:
            if not hasattr(self.model, 'created_at') and not hasattr(self.model, 'date'):
                logger.warning(f"Model {self.model.__name__} has no date field for cleanup")
                return 0

            # TODO: implement the actual batched delete (likely raw SQL).
            logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
            return 0

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         tenant_id=tenant_id,
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    def _ensure_utc_datetime(self, dt: Optional[datetime]) -> Optional[datetime]:
        """Ensure datetime is UTC timezone aware.

        ``None`` is passed through; a naive datetime is assumed to already be
        UTC and is only tagged; an aware datetime is converted to UTC.
        """
        if dt is None:
            return None

        if dt.tzinfo is None:
            # Assume naive datetime is UTC
            return dt.replace(tzinfo=timezone.utc)

        return dt.astimezone(timezone.utc)