Root cause: NotificationBaseRepository._validate_notification_data compared enum objects against lists of strings, so validation failed whenever EnhancedNotificationService passed NotificationType/NotificationPriority/NotificationStatus enum members instead of their string values. The validator now accepts both enum members (comparing their .value) and plain strings, which fixes the "Invalid notification type" error reported by the orchestrator.
Changes:
- Priority validation now handles enum members.
- Notification type validation now handles enum members.
- Status validation now handles enum members.
Fixes the error: "Invalid notification data: ['Invalid notification type. Must be one of: ['email', 'whatsapp', 'push', 'sms']']"
265 lines · 10 KiB · Python
"""
|
|
Base Repository for Notification Service
|
|
Service-specific repository base class with notification utilities
|
|
"""
|
|
|
|
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional, List, Dict, Any, Type

import structlog
from sqlalchemy import text, and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
|
|
|
|
# Module-level structlog logger; the repository methods below log through it
# with key/value context (model, tenant_id, error, ...).
logger = structlog.get_logger()
|
|
|
|
|
|
class NotificationBaseRepository(BaseRepository):
    """Base repository for notification service with common notification operations.

    Extends the shared ``BaseRepository`` with notification-oriented helpers:
    lookups by tenant / user / status / recency, a retention cleanup, per-tenant
    statistics and a payload validator.
    """

    # Allowed values for the enum-like fields checked by
    # ``_validate_notification_data``. Enum members are compared by ``.value``,
    # so both ``"email"`` and ``NotificationType.EMAIL``-style members pass.
    _VALID_PRIORITIES = ("low", "normal", "high", "urgent")
    _VALID_TYPES = ("email", "whatsapp", "push", "sms")
    _VALID_STATUSES = ("pending", "sent", "delivered", "failed", "cancelled")

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Notifications change frequently, so default to a short cache TTL
        # (300 s = 5 minutes).
        super().__init__(model, session, cache_ttl)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same value as ``datetime.utcnow()`` (deprecated since
        Python 3.12). The result is kept naive so the query parameters built
        from it are unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for Enum members, otherwise ``value`` unchanged."""
        return value.value if isinstance(value, Enum) else value

    async def _get_newest_filtered(self, field: str, value: Any, skip: int, limit: int) -> List:
        """Return a page of records where ``field == value``, newest first.

        When the model has no ``field`` attribute, falls back to
        ``get_multi(skip=skip, limit=limit)`` with no filter and no explicit
        ordering.
        """
        if hasattr(self.model, field):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={field: value},
                order_by="created_at",
                order_desc=True,
            )
        # NOTE(review): the fallback drops the filter and returns unfiltered
        # records (for tenant_id this is not tenant-scoped) - confirm callers
        # only use these helpers on models that actually have the column.
        return await self.get_multi(skip=skip, limit=limit)

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by tenant ID, newest first.

        Returns an unfiltered page if the model has no ``tenant_id`` column.
        """
        return await self._get_newest_filtered("tenant_id", tenant_id, skip, limit)

    async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records for a user, newest first.

        The user column is chosen by model capability, first match wins:
        ``recipient_id``, then ``sender_id``, then ``user_id``. Only that one
        column is filtered on, so a model with both ``recipient_id`` and
        ``sender_id`` is matched on ``recipient_id`` only.

        Returns:
            Matching records, or ``[]`` when the model has none of the columns.
        """
        for field in ("recipient_id", "sender_id", "user_id"):
            if hasattr(self.model, field):
                return await self.get_multi(
                    skip=skip,
                    limit=limit,
                    filters={field: user_id},
                    order_by="created_at",
                    order_desc=True,
                )
        return []

    async def get_by_status(self, status: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by status, newest first.

        Returns an unfiltered page if the model has no ``status`` column.
        """
        return await self._get_newest_filtered("status", status, skip, limit)

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
        """Get active records (``is_active == True``), newest first.

        Returns an unfiltered page if the model has no ``is_active`` column.
        """
        return await self._get_newest_filtered("is_active", True, skip, limit)

    async def get_recent_records(self, hours: int = 24, skip: int = 0, limit: int = 100) -> List:
        """Get records created in the last ``hours`` hours, newest first.

        Uses an ORM ``SELECT`` so the returned objects are the session-tracked
        instances for those rows. Building ``self.model(**row)`` from raw SQL
        would create new, transient objects that the session would try to
        INSERT if they were later added and flushed.

        Returns:
            Up to ``limit`` records after skipping ``skip``. Returns ``[]``
            (after logging) if the query fails, e.g. when the model has no
            ``created_at`` column.
        """
        try:
            cutoff_time = self._utcnow() - timedelta(hours=hours)
            created_at = self.model.created_at
            stmt = (
                select(self.model)
                .where(created_at >= cutoff_time)
                .order_by(created_at.desc())
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())

        except Exception as e:
            # Deliberately best-effort: callers get an empty list, not an error.
            logger.error("Failed to get recent records",
                         model=self.model.__name__,
                         hours=hours,
                         error=str(e))
            return []

    # ------------------------------------------------------------------
    # Maintenance / statistics
    # ------------------------------------------------------------------

    async def cleanup_old_records(self, days_old: int = 90) -> int:
        """Clean up old notification records (90 days by default).

        Deletes rows whose ``created_at`` is older than ``days_old`` days. When
        the model has a ``status`` column, only rows with status
        'delivered', 'cancelled' or 'failed' are deleted. The DELETE runs on
        the current session; this method does not commit.

        Returns:
            Number of deleted rows (``result.rowcount``).

        Raises:
            DatabaseError: if the DELETE fails (original exception chained).
        """
        try:
            cutoff_date = self._utcnow() - timedelta(days=days_old)
            # Table name comes from the model class, not from caller input.
            table_name = self.model.__tablename__

            # Only delete successfully processed or cancelled records that are old
            conditions = ["created_at < :cutoff_date"]
            if hasattr(self.model, 'status'):
                # NOTE(review): assumes statuses are stored as these lowercase
                # strings (not enum names) - confirm against the column type.
                conditions.append("status IN ('delivered', 'cancelled', 'failed')")

            query_text = f"""
                DELETE FROM {table_name}
                WHERE {' AND '.join(conditions)}
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """Get statistics for a tenant.

        Returns:
            ``{"total_records": int, "recent_records_24h": int,
            "status_breakdown": {status: count}}``. The breakdown is empty when
            the model has no ``status`` column. On any failure the error is
            logged and an all-zero result is returned.
        """
        try:
            table_name = self.model.__tablename__

            # Get basic counts
            total_records = await self.count(filters={"tenant_id": tenant_id})

            # Get recent activity (records in last 24 hours)
            twenty_four_hours_ago = self._utcnow() - timedelta(hours=24)
            recent_query = text(f"""
                SELECT COUNT(*) AS count
                FROM {table_name}
                WHERE tenant_id = :tenant_id
                  AND created_at >= :twenty_four_hours_ago
            """)
            result = await self.session.execute(recent_query, {
                "tenant_id": tenant_id,
                "twenty_four_hours_ago": twenty_four_hours_ago,
            })
            recent_records = result.scalar() or 0

            # Get status breakdown if applicable
            status_breakdown: Dict[Any, int] = {}
            if hasattr(self.model, 'status'):
                status_query = text(f"""
                    SELECT status, COUNT(*) AS count
                    FROM {table_name}
                    WHERE tenant_id = :tenant_id
                    GROUP BY status
                """)
                result = await self.session.execute(status_query, {"tenant_id": tenant_id})
                # Unpack positionally: on SQLAlchemy's tuple-like Row,
                # ``row.count`` resolves to the ``count()`` method, not the
                # column, so attribute access would store bound methods.
                status_breakdown = {status: n for status, n in result.fetchall()}

            return {
                "total_records": total_records,
                "recent_records_24h": recent_records,
                "status_breakdown": status_breakdown,
            }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_records": 0,
                "recent_records_24h": 0,
                "status_breakdown": {},
            }

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def _validate_notification_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate notification-related data.

        All problems are collected rather than stopping at the first one.
        Optional fields are only checked when present with a truthy value.
        ``priority``, ``type`` and ``status`` may be plain strings or Enum
        members; Enum members are compared by their ``.value``.

        Args:
            data: Candidate notification payload.
            required_fields: Keys that must be present with a truthy value.

        Returns:
            ``{"is_valid": bool, "errors": List[str]}``.
        """
        errors: List[str] = []

        for field in required_fields:
            if field not in data or not data[field]:
                errors.append(f"Missing required field: {field}")

        # A truthy str is never empty, so only the type needs checking.
        tenant_id = data.get("tenant_id")
        if tenant_id and not isinstance(tenant_id, str):
            errors.append("Invalid tenant_id format")

        for field in ("user_id", "recipient_id", "sender_id"):
            user_id = data.get(field)
            if user_id and not isinstance(user_id, str):
                errors.append(f"Invalid {field} format")

        # Lightweight sanity check, not full RFC 5322 validation. Non-string
        # values are reported as invalid instead of raising TypeError.
        email = data.get("recipient_email")
        if email and (
            not isinstance(email, str)
            or "@" not in email
            or "." not in email.split("@")[-1]
        ):
            errors.append("Invalid email format")

        phone = data.get("recipient_phone")
        if phone and (not isinstance(phone, str) or len(phone) < 9):
            errors.append("Invalid phone format")

        # (payload key, allowed values, error prefix) - order fixes error order.
        enum_checks = (
            ("priority", self._VALID_PRIORITIES, "Invalid priority"),
            ("type", self._VALID_TYPES, "Invalid notification type"),
            ("status", self._VALID_STATUSES, "Invalid status"),
        )
        for key, allowed, label in enum_checks:
            raw = data.get(key)
            if raw and self._enum_value(raw) not in allowed:
                # list(...) keeps the message formatted as "['a', 'b', ...]".
                errors.append(f"{label}. Must be one of: {list(allowed)}")

        return {
            "is_valid": not errors,
            "errors": errors,
        }