277 lines
11 KiB
Python
277 lines
11 KiB
Python
"""
|
|
User Repository
|
|
Repository for user operations with authentication-specific queries
|
|
"""
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, or_, func, desc, text
|
|
from datetime import datetime, timezone, timedelta
|
|
import structlog
|
|
|
|
from .base import AuthBaseRepository
|
|
from app.models.users import User
|
|
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
|
|
|
# Module-level structlog logger used by UserRepository for structured events.
logger = structlog.get_logger()
|
|
|
|
|
|
class UserRepository(AuthBaseRepository):
    """Repository for user operations.

    Adds user-specific validation, authentication, profile and statistics
    queries on top of ``AuthBaseRepository``. Most write/query methods log
    unexpected failures and re-raise them as ``DatabaseError``; validation and
    duplicate-email problems surface as ``ValidationError`` /
    ``DuplicateRecordError``.
    """

    # Fields that must never be changed through ``update_user_profile``;
    # they are dropped from the incoming payload before the update runs.
    _PROFILE_PROTECTED_FIELDS = frozenset({"id", "hashed_password", "created_at", "is_active"})

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Initialise the repository.

        Args:
            model: ORM model class handled by this repository (presumably
                ``User`` — confirm at the call site).
            session: Async SQLAlchemy session used for all queries.
            cache_ttl: Cache TTL forwarded unchanged to ``AuthBaseRepository``
                (units assumed to be seconds — TODO confirm in the base class).
        """
        super().__init__(model, session, cache_ttl)

    async def create_user(self, user_data: Dict[str, Any]) -> User:
        """Create a new user after validating the payload and email uniqueness.

        Args:
            user_data: Column values for the new user; ``email``,
                ``hashed_password``, ``full_name`` and ``role`` are required.

        Returns:
            The newly created ``User``.

        Raises:
            ValidationError: If the payload fails ``_validate_auth_data``.
            DuplicateRecordError: If a user with the same email already exists.
            DatabaseError: On any other failure while creating the user.
        """
        try:
            validation_result = self._validate_auth_data(
                user_data,
                ["email", "hashed_password", "full_name", "role"],
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid user data: {validation_result['errors']}")

            # NOTE(review): check-then-insert is not atomic; concurrent sign-ups
            # with the same email rely on a DB unique constraint (if one exists).
            existing_user = await self.get_by_email(user_data["email"])
            if existing_user:
                raise DuplicateRecordError(f"User with email {user_data['email']} already exists")

            user = await self.create(user_data)

            logger.info("User created successfully",
                        user_id=user.id,
                        email=user.email,
                        role=user.role)
            return user

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create user",
                         email=user_data.get("email"),
                         error=str(e))
            raise DatabaseError(f"Failed to create user: {str(e)}") from e

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the user with the given email, delegating to ``get_by_email``."""
        return await self.get_by_email(email)

    async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Return a page of active users via the base ``get_active_records``."""
        return await self.get_active_records(skip=skip, limit=limit)

    async def authenticate_user(self, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email and plain-text password.

        On success the user's ``last_login`` is refreshed on a best-effort basis.

        Args:
            email: Email address to look up.
            password: Plain-text password to verify against the stored hash.

        Returns:
            The ``User`` if it exists, is active and the password matches;
            otherwise ``None``.

        Raises:
            DatabaseError: If the lookup or password verification fails
                unexpectedly.
        """
        try:
            user = await self.get_by_email(email)

            # NOTE(review): the unknown-email and inactive paths return without
            # running password verification, so response time may differ from
            # the wrong-password path.
            if not user:
                logger.debug("User not found for authentication", email=email)
                return None

            if not user.is_active:
                logger.debug("User account is inactive", email=email)
                return None

            # Imported lazily, presumably to avoid an import cycle with
            # app.core.security — confirm before moving it to module level.
            from app.core.security import SecurityManager

            if SecurityManager.verify_password(password, user.hashed_password):
                # update_last_login swallows its own errors, so a failure here
                # does not block a successful login.
                await self.update_last_login(user.id)
                logger.info("User authenticated successfully",
                            user_id=user.id,
                            email=email)
                return user

            logger.debug("Invalid password for user", email=email)
            return None

        except Exception as e:
            logger.error("Authentication failed",
                         email=email,
                         error=str(e))
            raise DatabaseError(f"Authentication failed: {str(e)}") from e

    async def update_last_login(self, user_id: str) -> Optional[User]:
        """Set the user's ``last_login`` to the current UTC time.

        Best effort: any failure is logged and ``None`` is returned instead of
        raising, because a missed last-login update is not critical.
        """
        try:
            return await self.update(user_id, {
                "last_login": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to update last login",
                         user_id=user_id,
                         error=str(e))
            return None

    async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
        """Update a user's profile fields.

        The protected fields ``id``, ``hashed_password``, ``created_at`` and
        ``is_active`` are silently ignored. The caller's ``profile_data`` dict
        is NOT modified. If ``email`` is being changed, it is validated and
        checked against other users.

        Args:
            user_id: Identifier of the user to update.
            profile_data: Field/value pairs to apply.

        Returns:
            Whatever ``self.update`` returns (the updated ``User``, presumably
            ``None`` when ``user_id`` is unknown).

        Raises:
            ValidationError: If the new email fails validation.
            DuplicateRecordError: If the new email belongs to another user.
            DatabaseError: On any other failure.
        """
        try:
            # Filter into a new dict so the caller's payload is left untouched.
            updates = {
                field: value
                for field, value in profile_data.items()
                if field not in self._PROFILE_PROTECTED_FIELDS
            }
            # NOTE(review): ``role`` and ``is_verified`` are not protected and
            # can therefore be changed through this method — confirm intended.

            if "email" in updates:
                validation_result = self._validate_auth_data(updates, ["email"])
                if not validation_result["is_valid"]:
                    raise ValidationError(f"Invalid profile data: {validation_result['errors']}")

                # Reject an email that already belongs to a different user;
                # ids are compared as strings to tolerate UUID vs str ids.
                existing_user = await self.get_by_email(updates["email"])
                if existing_user and str(existing_user.id) != str(user_id):
                    raise DuplicateRecordError(f"Email {updates['email']} is already in use")

            updated_user = await self.update(user_id, updates)

            if updated_user:
                logger.info("User profile updated",
                            user_id=user_id,
                            updated_fields=list(updates.keys()))

            return updated_user

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to update user profile",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update profile: {str(e)}") from e

    async def change_password(self, user_id: str, new_password_hash: str) -> bool:
        """Replace the user's stored password hash.

        Args:
            user_id: Identifier of the user.
            new_password_hash: Already-hashed password to store.

        Returns:
            True if ``self.update`` returned a user, False otherwise.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            updated_user = await self.update(user_id, {
                "hashed_password": new_password_hash
            })

            if updated_user:
                logger.info("Password changed successfully", user_id=user_id)
                return True

            return False

        except Exception as e:
            logger.error("Failed to change password",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to change password: {str(e)}") from e

    async def verify_user_email(self, user_id: str) -> Optional[User]:
        """Set ``is_verified`` to True for the user.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            return await self.update(user_id, {
                "is_verified": True
            })
        except Exception as e:
            logger.error("Failed to verify user email",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to verify email: {str(e)}") from e

    async def deactivate_user(self, user_id: str) -> Optional[User]:
        """Deactivate the user account via the base ``deactivate_record``."""
        return await self.deactivate_record(user_id)

    async def activate_user(self, user_id: str) -> Optional[User]:
        """Activate the user account via the base ``activate_record``."""
        return await self.activate_record(user_id)

    async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
        """Return a page of ACTIVE users with the given role, newest first.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"role": role, "is_active": True},
                order_by="created_at",
                order_desc=True,
            )
        except Exception as e:
            logger.error("Failed to get users by role",
                         role=role,
                         error=str(e))
            raise DatabaseError(f"Failed to get users by role: {str(e)}") from e

    async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
        """Search users whose email or full name matches ``search_term``.

        Matching semantics are those of the base ``search`` method.

        Raises:
            DatabaseError: If the search fails.
        """
        try:
            return await self.search(
                search_term=search_term,
                search_fields=["email", "full_name"],
                skip=skip,
                limit=limit,
            )
        except Exception as e:
            logger.error("Failed to search users",
                         search_term=search_term,
                         error=str(e))
            raise DatabaseError(f"Failed to search users: {str(e)}") from e

    async def get_user_statistics(self) -> Dict[str, Any]:
        """Collect aggregate user statistics.

        Returns:
            A dict with:

            - ``total_users``: all users.
            - ``active_users``: users with ``is_active`` true.
            - ``inactive_users``: ``total_users - active_users``.
            - ``verified_users``: users with ``is_verified`` true, whether or
              not they are active.
            - ``unverified_users``: ``total_users - verified_users``.
            - ``recent_registrations``: users created in the last 30 days.
            - ``users_by_role``: active users per role, largest group first.

            On any failure the error is logged and a dict with the same keys,
            all zero/empty, is returned instead of raising.
        """
        try:
            total_users = await self.count()
            active_users = await self.count(filters={"is_active": True})
            verified_users = await self.count(filters={"is_verified": True})

            # Active users grouped by role (raw SQL against the ``users`` table).
            role_query = text("""
                SELECT role, COUNT(*) as count
                FROM users
                WHERE is_active = true
                GROUP BY role
                ORDER BY count DESC
            """)
            result = await self.session.execute(role_query)
            # Unpack positionally: on some SQLAlchemy versions ``row.count``
            # resolves to the tuple-like Row.count() method, not the column.
            role_stats = {role: count for role, count in result.fetchall()}

            # Recent activity: users created in the last 30 days.
            thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
            recent_users_query = text("""
                SELECT COUNT(*) as count
                FROM users
                WHERE created_at >= :thirty_days_ago
            """)
            recent_result = await self.session.execute(
                recent_users_query,
                {"thirty_days_ago": thirty_days_ago},
            )
            recent_users = recent_result.scalar() or 0

            return {
                "total_users": total_users,
                "active_users": active_users,
                "inactive_users": total_users - active_users,
                "verified_users": verified_users,
                # verified_users counts active and inactive users alike, so its
                # complement is taken against total_users. Subtracting it from
                # active_users (as before) could go negative.
                "unverified_users": total_users - verified_users,
                "recent_registrations": recent_users,
                "users_by_role": role_stats,
            }

        except Exception as e:
            logger.error("Failed to get user statistics", error=str(e))
            return {
                "total_users": 0,
                "active_users": 0,
                "inactive_users": 0,
                "verified_users": 0,
                "unverified_users": 0,
                "recent_registrations": 0,
                "users_by_role": {},
            }