101 lines
4.0 KiB
Python
101 lines
4.0 KiB
Python
"""
Base Repository for Auth Service

Service-specific base repository class providing shared auth-service utilities.
"""
|
|
|
|
from typing import Optional, List, Dict, Any, Type
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from datetime import datetime, timezone
|
|
import structlog
|
|
|
|
from shared.database.repository import BaseRepository
|
|
from shared.database.exceptions import DatabaseError, ValidationError
|
|
|
|
# Module-level structlog logger used by the repository methods in this module.
logger = structlog.get_logger()
|
|
|
|
|
|
class AuthBaseRepository(BaseRepository):
    """Base repository for the auth service with common auth operations.

    The helpers are model-agnostic: each one checks with ``hasattr`` whether
    the wrapped model declares the column it needs (``is_active``, ``email``,
    ``username``, ...) and otherwise falls back to the behavior documented on
    that method.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Create the repository.

        Args:
            model: Model class managed by this repository (presumably an
                SQLAlchemy ORM model -- confirm against ``BaseRepository``).
            session: Async database session forwarded to ``BaseRepository``.
            cache_ttl: Cache TTL forwarded to ``BaseRepository``; the default
                of 600 corresponds to 10 minutes.
        """
        # Auth data benefits from longer caching (10 minutes)
        super().__init__(model, session, cache_ttl)

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List[Any]:
        """Return a page of records, restricted to active ones when possible.

        If the model has an ``is_active`` column, only rows with
        ``is_active=True`` are requested, ordered by ``created_at`` descending.
        Otherwise an unfiltered page is requested with ``get_multi`` defaults.

        Args:
            skip: Offset forwarded to ``get_multi``.
            limit: Page size forwarded to ``get_multi``.
        """
        if hasattr(self.model, "is_active"):
            # NOTE(review): assumes every model with ``is_active`` also has a
            # ``created_at`` column to order by -- confirm for new models.
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True,
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_email(self, email: str) -> Optional[Any]:
        """Return the record whose ``email`` equals *email*.

        Returns ``None`` without querying when the model has no ``email``
        column; otherwise returns whatever ``get_by_field`` returns.
        """
        if hasattr(self.model, "email"):
            return await self.get_by_field("email", email)
        return None

    async def get_by_username(self, username: str) -> Optional[Any]:
        """Return the record whose ``username`` equals *username*.

        Returns ``None`` without querying when the model has no ``username``
        column; otherwise returns whatever ``get_by_field`` returns.
        """
        if hasattr(self.model, "username"):
            return await self.get_by_field("username", username)
        return None

    async def deactivate_record(self, record_id: Any) -> Optional[Any]:
        """Soft-delete a record by setting ``is_active=False``.

        Warning: when the model has no ``is_active`` column this falls back to
        a HARD delete via ``self.delete(record_id)`` and returns its result.
        """
        if hasattr(self.model, "is_active"):
            return await self.update(record_id, {"is_active": True is False})
        return await self.delete(record_id)

    async def activate_record(self, record_id: Any) -> Optional[Any]:
        """Re-activate a record by setting ``is_active=True``.

        When the model has no ``is_active`` column nothing is modified and the
        record is simply fetched with ``get_by_id``.
        """
        if hasattr(self.model, "is_active"):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)

    async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
        """Clean up expired records (for tokens, sessions, etc.).

        Not implemented yet: this only checks that the model has *field_name*
        and logs the request. No rows are touched and ``0`` is always returned.

        Args:
            field_name: Name of the expiry column the cleanup would filter on.

        Returns:
            Number of records removed -- currently always ``0``.

        Raises:
            DatabaseError: If anything fails while handling the request; the
                original exception is chained as ``__cause__``.
        """
        try:
            if not hasattr(self.model, field_name):
                logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
                return 0

            # TODO: implement the actual delete (e.g. rows whose ``field_name``
            # is earlier than the current UTC time); until then report 0.
            logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
            return 0

        except Exception as e:
            logger.error("Failed to cleanup expired records",
                         model=self.model.__name__,
                         error=str(e))
            # Chain the cause so the original traceback is not lost.
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    @staticmethod
    def _is_valid_email(email: Any) -> bool:
        """Return True for a minimally well-formed ``local@domain.tld`` string.

        Deliberately lightweight (not RFC 5322). It requires:
        - a ``str`` containing exactly one ``@``;
        - a non-empty local part;
        - a domain made of at least two non-empty, dot-separated labels.
        """
        if not isinstance(email, str) or email.count("@") != 1:
            return False
        local, _, domain = email.partition("@")
        labels = domain.split(".")
        # Rejects "@x.com", "a@.com", "a@com." and "a@localhost".
        return bool(local) and len(labels) >= 2 and all(labels)

    def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate authentication-related data.

        Checks, in this order:
          * every name in *required_fields* is present with a truthy value
            (``None``, ``""``, ``0`` and ``False`` all count as missing);
          * ``email``, if present and truthy, is well-formed (see
            ``_is_valid_email``);
          * ``password``, if present and truthy, is at least 8 long.

        Args:
            data: Mapping of field name to submitted value.
            required_fields: Names that must be present with a truthy value.

        Returns:
            ``{"is_valid": bool, "errors": list[str]}``, where ``is_valid`` is
            True exactly when ``errors`` is empty.
        """
        errors: List[str] = [
            f"Missing required field: {field}"
            for field in required_fields
            if field not in data or not data[field]
        ]

        # Validate email format if present
        email = data.get("email")
        if email and not self._is_valid_email(email):
            errors.append("Invalid email format")

        # Validate password strength if present
        password = data.get("password")
        if password and len(password) < 8:
            errors.append("Password must be at least 8 characters long")

        return {
            "is_valid": len(errors) == 0,
            "errors": errors,
        }