REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,14 @@
"""
Auth Service Repositories
Repository implementations for authentication service
"""
from .base import AuthBaseRepository
from .user_repository import UserRepository
from .token_repository import TokenRepository
__all__ = [
"AuthBaseRepository",
"UserRepository",
"TokenRepository"
]

View File

@@ -0,0 +1,101 @@
"""
Base Repository for Auth Service
Service-specific repository base class with auth service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class AuthBaseRepository(BaseRepository):
"""Base repository for auth service with common auth operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Auth data benefits from longer caching (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_email(self, email: str) -> Optional:
"""Get record by email (if model has email field)"""
if hasattr(self.model, 'email'):
return await self.get_by_field("email", email)
return None
async def get_by_username(self, username: str) -> Optional:
"""Get record by username (if model has username field)"""
if hasattr(self.model, 'username'):
return await self.get_by_field("username", username)
return None
async def deactivate_record(self, record_id: Any) -> Optional:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
"""Clean up expired records (for tokens, sessions, etc.)"""
try:
if not hasattr(self.model, field_name):
logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
return 0
# This would need custom implementation with raw SQL for date comparison
# For now, return 0 to indicate no cleanup performed
logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
return 0
except Exception as e:
logger.error("Failed to cleanup expired records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate authentication-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate email format if present
if "email" in data and data["email"]:
email = data["email"]
if "@" not in email or "." not in email.split("@")[-1]:
errors.append("Invalid email format")
# Validate password strength if present
if "password" in data and data["password"]:
password = data["password"]
if len(password) < 8:
errors.append("Password must be at least 8 characters long")
return {
"is_valid": len(errors) == 0,
"errors": errors
}

View File

@@ -0,0 +1,269 @@
"""
Token Repository
Repository for refresh token operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from datetime import datetime, timezone, timedelta
import structlog
from .base import AuthBaseRepository
from app.models.users import RefreshToken
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TokenRepository(AuthBaseRepository):
"""Repository for refresh token operations"""
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Tokens change frequently, shorter cache time
super().__init__(model, session, cache_ttl)
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
"""Create a new refresh token from dictionary data"""
return await self.create(token_data)
async def create_refresh_token(
self,
user_id: str,
token: str,
expires_at: datetime
) -> RefreshToken:
"""Create a new refresh token"""
try:
token_data = {
"user_id": user_id,
"token": token,
"expires_at": expires_at,
"is_revoked": False
}
refresh_token = await self.create(token_data)
logger.debug("Refresh token created",
user_id=user_id,
token_id=refresh_token.id,
expires_at=expires_at)
return refresh_token
except Exception as e:
logger.error("Failed to create refresh token",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to create refresh token: {str(e)}")
async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
"""Get refresh token by token value"""
try:
return await self.get_by_field("token", token)
except Exception as e:
logger.error("Failed to get token by value", error=str(e))
raise DatabaseError(f"Failed to get token: {str(e)}")
async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
"""Get all active (non-revoked, non-expired) tokens for a user"""
try:
now = datetime.now(timezone.utc)
# Use raw query for complex filtering
query = text("""
SELECT * FROM refresh_tokens
WHERE user_id = :user_id
AND is_revoked = false
AND expires_at > :now
ORDER BY created_at DESC
""")
result = await self.session.execute(query, {
"user_id": user_id,
"now": now
})
# Convert rows to RefreshToken objects
tokens = []
for row in result.fetchall():
token = RefreshToken(
id=row.id,
user_id=row.user_id,
token=row.token,
expires_at=row.expires_at,
is_revoked=row.is_revoked,
created_at=row.created_at,
revoked_at=row.revoked_at
)
tokens.append(token)
return tokens
except Exception as e:
logger.error("Failed to get active tokens for user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get active tokens: {str(e)}")
async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
"""Revoke a refresh token"""
try:
return await self.update(token_id, {
"is_revoked": True,
"revoked_at": datetime.now(timezone.utc)
})
except Exception as e:
logger.error("Failed to revoke token",
token_id=token_id,
error=str(e))
raise DatabaseError(f"Failed to revoke token: {str(e)}")
async def revoke_all_user_tokens(self, user_id: str) -> int:
"""Revoke all tokens for a user"""
try:
# Use bulk update for efficiency
now = datetime.now(timezone.utc)
query = text("""
UPDATE refresh_tokens
SET is_revoked = true, revoked_at = :revoked_at
WHERE user_id = :user_id AND is_revoked = false
""")
result = await self.session.execute(query, {
"user_id": user_id,
"revoked_at": now
})
revoked_count = result.rowcount
logger.info("Revoked all user tokens",
user_id=user_id,
revoked_count=revoked_count)
return revoked_count
except Exception as e:
logger.error("Failed to revoke all user tokens",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")
async def is_token_valid(self, token: str) -> bool:
"""Check if a token is valid (exists, not revoked, not expired)"""
try:
refresh_token = await self.get_token_by_value(token)
if not refresh_token:
return False
if refresh_token.is_revoked:
return False
if refresh_token.expires_at < datetime.now(timezone.utc):
return False
return True
except Exception as e:
logger.error("Failed to validate token", error=str(e))
return False
async def cleanup_expired_tokens(self) -> int:
"""Clean up expired refresh tokens"""
try:
now = datetime.now(timezone.utc)
# Delete expired tokens
query = text("""
DELETE FROM refresh_tokens
WHERE expires_at < :now
""")
result = await self.session.execute(query, {"now": now})
deleted_count = result.rowcount
logger.info("Cleaned up expired tokens",
deleted_count=deleted_count)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired tokens", error=str(e))
raise DatabaseError(f"Token cleanup failed: {str(e)}")
async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
"""Clean up old revoked tokens"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
query = text("""
DELETE FROM refresh_tokens
WHERE is_revoked = true
AND revoked_at < :cutoff_date
""")
result = await self.session.execute(query, {
"cutoff_date": cutoff_date
})
deleted_count = result.rowcount
logger.info("Cleaned up old revoked tokens",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old revoked tokens",
days_old=days_old,
error=str(e))
raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")
async def get_token_statistics(self) -> Dict[str, Any]:
"""Get token statistics"""
try:
now = datetime.now(timezone.utc)
# Get counts with raw queries
stats_query = text("""
SELECT
COUNT(*) as total_tokens,
COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
COUNT(DISTINCT user_id) as users_with_tokens
FROM refresh_tokens
""")
result = await self.session.execute(stats_query, {"now": now})
row = result.fetchone()
if row:
return {
"total_tokens": row.total_tokens,
"active_tokens": row.active_tokens,
"revoked_tokens": row.revoked_tokens,
"expired_tokens": row.expired_tokens,
"users_with_tokens": row.users_with_tokens
}
return {
"total_tokens": 0,
"active_tokens": 0,
"revoked_tokens": 0,
"expired_tokens": 0,
"users_with_tokens": 0
}
except Exception as e:
logger.error("Failed to get token statistics", error=str(e))
return {
"total_tokens": 0,
"active_tokens": 0,
"revoked_tokens": 0,
"expired_tokens": 0,
"users_with_tokens": 0
}

View File

@@ -0,0 +1,277 @@
"""
User Repository
Repository for user operations with authentication-specific queries
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, text
from datetime import datetime, timezone, timedelta
import structlog
from .base import AuthBaseRepository
from app.models.users import User
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class UserRepository(AuthBaseRepository):
"""Repository for user operations"""
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
super().__init__(model, session, cache_ttl)
async def create_user(self, user_data: Dict[str, Any]) -> User:
"""Create a new user with validation"""
try:
# Validate user data
validation_result = self._validate_auth_data(
user_data,
["email", "hashed_password", "full_name", "role"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid user data: {validation_result['errors']}")
# Check if user already exists
existing_user = await self.get_by_email(user_data["email"])
if existing_user:
raise DuplicateRecordError(f"User with email {user_data['email']} already exists")
# Create user
user = await self.create(user_data)
logger.info("User created successfully",
user_id=user.id,
email=user.email,
role=user.role)
return user
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create user",
email=user_data.get("email"),
error=str(e))
raise DatabaseError(f"Failed to create user: {str(e)}")
async def get_user_by_email(self, email: str) -> Optional[User]:
"""Get user by email address"""
return await self.get_by_email(email)
async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
"""Get all active users"""
return await self.get_active_records(skip=skip, limit=limit)
async def authenticate_user(self, email: str, password: str) -> Optional[User]:
"""Authenticate user with email and plain password"""
try:
user = await self.get_by_email(email)
if not user:
logger.debug("User not found for authentication", email=email)
return None
if not user.is_active:
logger.debug("User account is inactive", email=email)
return None
# Verify password using security manager
from app.core.security import SecurityManager
if SecurityManager.verify_password(password, user.hashed_password):
# Update last login
await self.update_last_login(user.id)
logger.info("User authenticated successfully",
user_id=user.id,
email=email)
return user
logger.debug("Invalid password for user", email=email)
return None
except Exception as e:
logger.error("Authentication failed",
email=email,
error=str(e))
raise DatabaseError(f"Authentication failed: {str(e)}")
async def update_last_login(self, user_id: str) -> Optional[User]:
"""Update user's last login timestamp"""
try:
return await self.update(user_id, {
"last_login": datetime.now(timezone.utc)
})
except Exception as e:
logger.error("Failed to update last login",
user_id=user_id,
error=str(e))
# Don't raise here - last login update is not critical
return None
async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
"""Update user profile information"""
try:
# Remove sensitive fields that shouldn't be updated via profile
profile_data.pop("id", None)
profile_data.pop("hashed_password", None)
profile_data.pop("created_at", None)
profile_data.pop("is_active", None)
# Validate email if being updated
if "email" in profile_data:
validation_result = self._validate_auth_data(
profile_data,
["email"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid profile data: {validation_result['errors']}")
# Check for email conflicts
existing_user = await self.get_by_email(profile_data["email"])
if existing_user and str(existing_user.id) != str(user_id):
raise DuplicateRecordError(f"Email {profile_data['email']} is already in use")
updated_user = await self.update(user_id, profile_data)
if updated_user:
logger.info("User profile updated",
user_id=user_id,
updated_fields=list(profile_data.keys()))
return updated_user
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to update user profile",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to update profile: {str(e)}")
async def change_password(self, user_id: str, new_password_hash: str) -> bool:
"""Change user password"""
try:
updated_user = await self.update(user_id, {
"hashed_password": new_password_hash
})
if updated_user:
logger.info("Password changed successfully", user_id=user_id)
return True
return False
except Exception as e:
logger.error("Failed to change password",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to change password: {str(e)}")
async def verify_user_email(self, user_id: str) -> Optional[User]:
"""Mark user email as verified"""
try:
return await self.update(user_id, {
"is_verified": True
})
except Exception as e:
logger.error("Failed to verify user email",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to verify email: {str(e)}")
async def deactivate_user(self, user_id: str) -> Optional[User]:
"""Deactivate user account"""
return await self.deactivate_record(user_id)
async def activate_user(self, user_id: str) -> Optional[User]:
"""Activate user account"""
return await self.activate_record(user_id)
async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
"""Get users by role"""
try:
return await self.get_multi(
skip=skip,
limit=limit,
filters={"role": role, "is_active": True},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get users by role",
role=role,
error=str(e))
raise DatabaseError(f"Failed to get users by role: {str(e)}")
async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
"""Search users by email or full name"""
try:
return await self.search(
search_term=search_term,
search_fields=["email", "full_name"],
skip=skip,
limit=limit
)
except Exception as e:
logger.error("Failed to search users",
search_term=search_term,
error=str(e))
raise DatabaseError(f"Failed to search users: {str(e)}")
async def get_user_statistics(self) -> Dict[str, Any]:
"""Get user statistics"""
try:
# Get basic counts
total_users = await self.count()
active_users = await self.count(filters={"is_active": True})
verified_users = await self.count(filters={"is_verified": True})
# Get users by role using raw query
role_query = text("""
SELECT role, COUNT(*) as count
FROM users
WHERE is_active = true
GROUP BY role
ORDER BY count DESC
""")
result = await self.session.execute(role_query)
role_stats = {row.role: row.count for row in result.fetchall()}
# Recent activity (users created in last 30 days)
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
recent_users_query = text("""
SELECT COUNT(*) as count
FROM users
WHERE created_at >= :thirty_days_ago
""")
recent_result = await self.session.execute(
recent_users_query,
{"thirty_days_ago": thirty_days_ago}
)
recent_users = recent_result.scalar() or 0
return {
"total_users": total_users,
"active_users": active_users,
"inactive_users": total_users - active_users,
"verified_users": verified_users,
"unverified_users": active_users - verified_users,
"recent_registrations": recent_users,
"users_by_role": role_stats
}
except Exception as e:
logger.error("Failed to get user statistics", error=str(e))
return {
"total_users": 0,
"active_users": 0,
"inactive_users": 0,
"verified_users": 0,
"unverified_users": 0,
"recent_registrations": 0,
"users_by_role": {}
}