234 lines
9.0 KiB
Python
234 lines
9.0 KiB
Python
"""
Base Repository for Tenant Service

Service-specific repository base class with tenant management utilities.
"""
|
|
|
|
from typing import Optional, List, Dict, Any, Type
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import text
|
|
from datetime import datetime, timedelta
|
|
import structlog
|
|
import json
|
|
|
|
from shared.database.repository import BaseRepository
|
|
from shared.database.exceptions import DatabaseError
|
|
|
|
# Module-level structlog logger; TenantBaseRepository uses it for its info/error events.
logger = structlog.get_logger()
|
|
|
|
|
|
class TenantBaseRepository(BaseRepository):
    """Base repository for tenant service with common tenant operations.

    Adds tenant-oriented helpers on top of the shared ``BaseRepository``:
    tenant/user scoped listing, soft activation/deactivation, cleanup of old
    records, per-tenant statistics and input validation. Most helpers adapt
    to the model's columns (``tenant_id``, ``user_id``/``owner_id``,
    ``is_active``) via ``hasattr`` checks.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Create the repository.

        Args:
            model: ORM model class managed by this repository.
            session: Async database session used for all queries.
            cache_ttl: Cache lifetime forwarded to ``BaseRepository``. The
                600 default (10 minutes) is a medium cache time, because
                tenant data is relatively stable.
        """
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records for ``tenant_id``, ordered by ``created_at`` descending.

        NOTE(review): if the model has no ``tenant_id`` attribute, this falls
        back to an *unfiltered* page of records. Confirm that no caller relies
        on this method for tenant isolation on such models.
        """
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True,
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records referencing ``user_id`` (for cross-service references).

        Filters on the model's ``user_id`` column when present, otherwise on
        ``owner_id``. Results are ordered by ``created_at`` descending.
        Returns an empty list when the model has neither column.
        """
        # The first matching column wins: user_id takes precedence over owner_id.
        for column in ("user_id", "owner_id"):
            if hasattr(self.model, column):
                return await self.get_multi(
                    skip=skip,
                    limit=limit,
                    filters={column: user_id},
                    order_by="created_at",
                    order_desc=True,
                )
        return []

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
        """Get records with ``is_active=True``, ordered by ``created_at`` descending.

        If the model has no ``is_active`` column, this returns an unfiltered,
        unordered page from ``get_multi`` instead.
        """
        if hasattr(self.model, 'is_active'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True,
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def deactivate_record(self, record_id: Any) -> Optional[Any]:
        """Deactivate a record instead of deleting it.

        Sets ``is_active`` to False via ``self.update``. Warning: for models
        without an ``is_active`` column, this falls back to ``self.delete``
        (a hard delete).
        """
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": False})
        return await self.delete(record_id)

    async def activate_record(self, record_id: Any) -> Optional[Any]:
        """Activate a record by setting ``is_active`` to True.

        For models without an ``is_active`` column, this just fetches the
        record via ``self.get_by_id``.
        """
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)

    async def cleanup_old_records(self, days_old: int = 365) -> int:
        """Delete records created more than ``days_old`` days ago.

        The default is very conservative: 365 days (1 year). When the model
        has an ``is_active`` column, only inactive records are deleted.
        Otherwise *every* record older than the cutoff is deleted. This
        method issues no commit.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            DatabaseError: if building or executing the delete fails. The
                original exception is chained as ``__cause__``.
        """
        try:
            # NOTE(review): naive UTC timestamp; assumes created_at is stored
            # as naive UTC. Confirm against the schema.
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            table_name = self.model.__tablename__

            conditions = ["created_at < :cutoff_date"]
            if hasattr(self.model, 'is_active'):
                conditions.append("is_active = false")

            # table_name comes from the model definition, not from caller
            # input. The cutoff is passed as a bound parameter.
            query_text = f"""
                DELETE FROM {table_name}
                WHERE {' AND '.join(conditions)}
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            # Chain the cause so the underlying driver/SQL traceback is kept.
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """Return record counts for ``tenant_id``.

        The result has these keys:
            - ``total_records``
            - ``active_records``
            - ``inactive_records``
            - ``recent_records_7d``: rows whose ``created_at`` falls within
              the last 7 days.

        If the model has no ``is_active`` column, every record counts as
        active.

        This is best-effort: any exception is logged, and an all-zero dict is
        returned instead of raising.

        NOTE(review): the failed statement is not rolled back here, so the
        session may need a rollback before reuse. Confirm with callers.
        """
        try:
            table_name = self.model.__tablename__

            total_records = await self.count(filters={"tenant_id": tenant_id})

            # Without an is_active column, every record is treated as active.
            active_records = total_records
            if hasattr(self.model, 'is_active'):
                active_records = await self.count(filters={
                    "tenant_id": tenant_id,
                    "is_active": True,
                })

            seven_days_ago = datetime.utcnow() - timedelta(days=7)
            recent_query = text(f"""
                SELECT COUNT(*) as count
                FROM {table_name}
                WHERE tenant_id = :tenant_id
                AND created_at >= :seven_days_ago
            """)

            result = await self.session.execute(recent_query, {
                "tenant_id": tenant_id,
                "seven_days_ago": seven_days_ago,
            })
            recent_records = result.scalar() or 0

            return {
                "total_records": total_records,
                "active_records": active_records,
                "inactive_records": total_records - active_records,
                "recent_records_7d": recent_records,
            }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_records": 0,
                "active_records": 0,
                "inactive_records": 0,
                "recent_records_7d": 0,
            }

    def _validate_tenant_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate tenant-related data.

        Args:
            data: Mapping of field name to value.
            required_fields: Keys that must be present with a non-empty value.
                ``None``, empty strings and empty containers count as missing.
                Numeric values, including ``0``, ``0.0`` and ``False``, count
                as present.

        Returns:
            ``{"is_valid": bool, "errors": list of str}``. ``is_valid`` is
            True exactly when ``errors`` is empty.
        """
        errors: List[str] = []

        for field in required_fields:
            value = data.get(field)
            # 0 / 0.0 / False are legitimate values (e.g. latitude 0 on the
            # equator), so only None and empty non-numeric values are missing.
            if value is None or (not isinstance(value, (int, float)) and not value):
                errors.append(f"Missing required field: {field}")

        # Identifier fields, when supplied with a truthy value, must be strings.
        for id_field in ("tenant_id", "user_id", "owner_id"):
            id_value = data.get(id_field)
            if id_value and not isinstance(id_value, str):
                errors.append(f"Invalid {id_field} format")

        # Basic email shape check: contains "@" and a dotted domain part.
        # Non-string values are reported as invalid instead of raising.
        email = data.get("email")
        if email:
            if (not isinstance(email, str)
                    or "@" not in email
                    or "." not in email.split("@")[-1]):
                errors.append("Invalid email format")

        # Basic phone validation: a string of at least 9 characters.
        phone = data.get("phone")
        if phone:
            if not isinstance(phone, str) or len(phone) < 9:
                errors.append("Invalid phone format")

        # Coordinates must parse as floats within +/- bound. The chained
        # comparison also rejects NaN, which a `< -bound or > bound` test
        # would silently accept.
        for coord, bound in (("latitude", 90), ("longitude", 180)):
            raw = data.get(coord)
            if raw is None:
                continue
            try:
                number = float(raw)
            except (ValueError, TypeError):
                errors.append(f"Invalid {coord} format")
                continue
            if not -bound <= number <= bound:
                errors.append(f"Invalid {coord} - must be between -{bound} and {bound}")

        # String-valued JSON fields must parse. Non-string values are not checked.
        json_fields = ["permissions"]
        for field in json_fields:
            if field in data and data[field]:
                if isinstance(data[field], str):
                    try:
                        json.loads(data[field])
                    except json.JSONDecodeError:
                        errors.append(f"Invalid JSON format in {field}")

        return {
            "is_valid": len(errors) == 0,
            "errors": errors,
        }