REFACTOR - Database logic
This commit is contained in:
269
services/auth/app/repositories/token_repository.py
Normal file
269
services/auth/app/repositories/token_repository.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Token Repository
|
||||
Repository for refresh token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.users import RefreshToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TokenRepository(AuthBaseRepository):
|
||||
"""Repository for refresh token operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Tokens change frequently, shorter cache time
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
|
||||
"""Create a new refresh token from dictionary data"""
|
||||
return await self.create(token_data)
|
||||
|
||||
async def create_refresh_token(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
expires_at: datetime
|
||||
) -> RefreshToken:
|
||||
"""Create a new refresh token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
refresh_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Refresh token created",
|
||||
user_id=user_id,
|
||||
token_id=refresh_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return refresh_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create refresh token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
|
||||
"""Get refresh token by token value"""
|
||||
try:
|
||||
return await self.get_by_field("token", token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get token: {str(e)}")
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
|
||||
"""Get all active (non-revoked, non-expired) tokens for a user"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use raw query for complex filtering
|
||||
query = text("""
|
||||
SELECT * FROM refresh_tokens
|
||||
WHERE user_id = :user_id
|
||||
AND is_revoked = false
|
||||
AND expires_at > :now
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"now": now
|
||||
})
|
||||
|
||||
# Convert rows to RefreshToken objects
|
||||
tokens = []
|
||||
for row in result.fetchall():
|
||||
token = RefreshToken(
|
||||
id=row.id,
|
||||
user_id=row.user_id,
|
||||
token=row.token,
|
||||
expires_at=row.expires_at,
|
||||
is_revoked=row.is_revoked,
|
||||
created_at=row.created_at,
|
||||
revoked_at=row.revoked_at
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active tokens for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active tokens: {str(e)}")
|
||||
|
||||
async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
|
||||
"""Revoke a refresh token"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_revoked": True,
|
||||
"revoked_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke token: {str(e)}")
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> int:
|
||||
"""Revoke all tokens for a user"""
|
||||
try:
|
||||
# Use bulk update for efficiency
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
query = text("""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = true, revoked_at = :revoked_at
|
||||
WHERE user_id = :user_id AND is_revoked = false
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"revoked_at": now
|
||||
})
|
||||
|
||||
revoked_count = result.rowcount
|
||||
|
||||
logger.info("Revoked all user tokens",
|
||||
user_id=user_id,
|
||||
revoked_count=revoked_count)
|
||||
|
||||
return revoked_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke all user tokens",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")
|
||||
|
||||
async def is_token_valid(self, token: str) -> bool:
|
||||
"""Check if a token is valid (exists, not revoked, not expired)"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate token", error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired refresh tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < :now
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
|
||||
"""Clean up old revoked tokens"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE is_revoked = true
|
||||
AND revoked_at < :cutoff_date
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old revoked tokens",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old revoked tokens",
|
||||
days_old=days_old,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_token_statistics(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get counts with raw queries
|
||||
stats_query = text("""
|
||||
SELECT
|
||||
COUNT(*) as total_tokens,
|
||||
COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
|
||||
COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
|
||||
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
|
||||
COUNT(DISTINCT user_id) as users_with_tokens
|
||||
FROM refresh_tokens
|
||||
""")
|
||||
|
||||
result = await self.session.execute(stats_query, {"now": now})
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"total_tokens": row.total_tokens,
|
||||
"active_tokens": row.active_tokens,
|
||||
"revoked_tokens": row.revoked_tokens,
|
||||
"expired_tokens": row.expired_tokens,
|
||||
"users_with_tokens": row.users_with_tokens
|
||||
}
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token statistics", error=str(e))
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
Reference in New Issue
Block a user