269 lines
9.5 KiB
Python
269 lines
9.5 KiB
Python
"""
Token Repository

Repository for refresh token operations: creation, lookup, revocation,
validation, cleanup and statistics.
"""
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, text
|
|
from datetime import datetime, timezone, timedelta
|
|
import structlog
|
|
|
|
from .base import AuthBaseRepository
|
|
from app.models.users import RefreshToken
|
|
from shared.database.exceptions import DatabaseError
|
|
|
|
# Module-level structlog logger used by every TokenRepository method below.
logger = structlog.get_logger()
|
|
|
|
|
|
class TokenRepository(AuthBaseRepository):
    """Repository for refresh token operations.

    Builds on the generic helpers inherited from ``AuthBaseRepository``
    (``create``, ``get_by_field``, ``update``) and adds token-specific
    queries: lookup by value, listing a user's active tokens, revocation,
    validity checks, cleanup of expired/revoked rows and aggregate statistics.

    A token is treated as *active* when ``is_revoked`` is false and
    ``expires_at`` lies strictly after the current UTC time; the listing,
    validation and statistics methods all use that same definition.

    NOTE(review): write operations execute on ``self.session`` but never
    commit; presumably the caller / unit of work commits — confirm.
    """

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Tokens change frequently, shorter cache time
        super().__init__(model, session, cache_ttl)

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return ``value`` as a timezone-aware UTC datetime.

        Naive values are assumed to already represent UTC. Without this,
        comparing a naive ``expires_at`` with ``datetime.now(timezone.utc)``
        raises ``TypeError``.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    @staticmethod
    def _empty_statistics() -> Dict[str, Any]:
        """Return a fresh all-zero statistics payload."""
        return {
            "total_tokens": 0,
            "active_tokens": 0,
            "revoked_tokens": 0,
            "expired_tokens": 0,
            "users_with_tokens": 0,
        }

    async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
        """Create a new refresh token from dictionary data.

        ``token_data`` is passed unchanged to the inherited ``create``. Unlike
        ``create_refresh_token``, no defaults are filled in and errors are not
        wrapped in ``DatabaseError``.
        """
        return await self.create(token_data)

    async def create_refresh_token(
        self,
        user_id: str,
        token: str,
        expires_at: datetime
    ) -> RefreshToken:
        """Create a new, non-revoked refresh token for ``user_id``.

        Args:
            user_id: Owner of the token.
            token: The token value to store.
            expires_at: Expiry timestamp stored on the row.

        Returns:
            The created ``RefreshToken``.

        Raises:
            DatabaseError: If creation fails for any reason.
        """
        try:
            token_data = {
                "user_id": user_id,
                "token": token,
                "expires_at": expires_at,
                "is_revoked": False,
            }

            refresh_token = await self.create(token_data)

            logger.debug("Refresh token created",
                         user_id=user_id,
                         token_id=refresh_token.id,
                         expires_at=expires_at)

            return refresh_token

        except Exception as e:
            logger.error("Failed to create refresh token",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create refresh token: {str(e)}") from e

    async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
        """Return the refresh token whose ``token`` column equals ``token``, or None.

        Raises:
            DatabaseError: If the lookup fails.
        """
        try:
            return await self.get_by_field("token", token)
        except Exception as e:
            logger.error("Failed to get token by value", error=str(e))
            raise DatabaseError(f"Failed to get token: {str(e)}") from e

    async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
        """Get all active (non-revoked, non-expired) tokens for a user.

        Results are ordered newest first (``created_at`` descending). The
        query goes through the ORM so the returned objects are the
        session-tracked instances, not detached copies rebuilt from raw rows.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            now = datetime.now(timezone.utc)

            query = (
                select(RefreshToken)
                .where(
                    and_(
                        RefreshToken.user_id == user_id,
                        # SQL comparison (renders "= false"), not Python truthiness.
                        RefreshToken.is_revoked == False,  # noqa: E712
                        RefreshToken.expires_at > now,
                    )
                )
                .order_by(RefreshToken.created_at.desc())
            )

            result = await self.session.execute(query)
            return list(result.scalars().all())

        except Exception as e:
            logger.error("Failed to get active tokens for user",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active tokens: {str(e)}") from e

    async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
        """Mark a single token revoked and stamp ``revoked_at`` with the current UTC time.

        Returns:
            Whatever the inherited ``update`` returns for ``token_id``.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            return await self.update(token_id, {
                "is_revoked": True,
                "revoked_at": datetime.now(timezone.utc),
            })
        except Exception as e:
            logger.error("Failed to revoke token",
                         token_id=token_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke token: {str(e)}") from e

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """Revoke every not-yet-revoked token belonging to ``user_id``.

        Uses a single bulk UPDATE. Raw SQL bypasses the ORM identity map, so
        RefreshToken instances already loaded in this session keep their old
        state until refreshed.

        Returns:
            Number of rows updated (``result.rowcount``).

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            now = datetime.now(timezone.utc)

            query = text("""
                UPDATE refresh_tokens
                SET is_revoked = true, revoked_at = :revoked_at
                WHERE user_id = :user_id AND is_revoked = false
            """)

            result = await self.session.execute(query, {
                "user_id": user_id,
                "revoked_at": now,
            })

            revoked_count = result.rowcount

            logger.info("Revoked all user tokens",
                        user_id=user_id,
                        revoked_count=revoked_count)

            return revoked_count

        except Exception as e:
            logger.error("Failed to revoke all user tokens",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke user tokens: {str(e)}") from e

    async def is_token_valid(self, token: str) -> bool:
        """Check if a token is valid (exists, not revoked, not expired).

        Expiry uses the same rule as ``get_active_tokens_for_user``: the token
        is valid only while ``expires_at`` is strictly after now. Naive
        ``expires_at`` values are treated as UTC rather than raising.

        Returns:
            True if valid. False otherwise, including on any error (fails closed).
        """
        try:
            refresh_token = await self.get_token_by_value(token)

            if not refresh_token:
                return False

            if refresh_token.is_revoked:
                return False

            return self._as_utc(refresh_token.expires_at) > datetime.now(timezone.utc)

        except Exception as e:
            # Deliberate best-effort: a failed check must never grant validity.
            logger.error("Failed to validate token", error=str(e))
            return False

    async def cleanup_expired_tokens(self) -> int:
        """Delete every token whose ``expires_at`` is before the current UTC time.

        Returns:
            Number of rows deleted (``result.rowcount``).

        Raises:
            DatabaseError: If the delete fails.
        """
        try:
            now = datetime.now(timezone.utc)

            query = text("""
                DELETE FROM refresh_tokens
                WHERE expires_at < :now
            """)

            result = await self.session.execute(query, {"now": now})
            deleted_count = result.rowcount

            logger.info("Cleaned up expired tokens",
                        deleted_count=deleted_count)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup expired tokens", error=str(e))
            raise DatabaseError(f"Token cleanup failed: {str(e)}") from e

    async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
        """Delete revoked tokens whose ``revoked_at`` is more than ``days_old`` days ago.

        Returns:
            Number of rows deleted (``result.rowcount``).

        Raises:
            DatabaseError: If the delete fails.
        """
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)

            query = text("""
                DELETE FROM refresh_tokens
                WHERE is_revoked = true
                AND revoked_at < :cutoff_date
            """)

            result = await self.session.execute(query, {
                "cutoff_date": cutoff_date,
            })

            deleted_count = result.rowcount

            logger.info("Cleaned up old revoked tokens",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old revoked tokens",
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Revoked token cleanup failed: {str(e)}") from e

    async def get_token_statistics(self) -> Dict[str, Any]:
        """Return aggregate token counts.

        Keys: ``total_tokens``, ``active_tokens`` (not revoked and
        ``expires_at > now``), ``revoked_tokens``, ``expired_tokens``
        (``expires_at <= now``; may overlap with revoked) and
        ``users_with_tokens`` (distinct ``user_id``).

        Best-effort: on error the failure is logged and all counts are 0.
        """
        try:
            now = datetime.now(timezone.utc)

            stats_query = text("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
                    COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
                    COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
                    COUNT(DISTINCT user_id) as users_with_tokens
                FROM refresh_tokens
            """)

            result = await self.session.execute(stats_query, {"now": now})
            row = result.fetchone()

            if row is None:
                return self._empty_statistics()

            return {
                "total_tokens": row.total_tokens,
                "active_tokens": row.active_tokens,
                "revoked_tokens": row.revoked_tokens,
                "expired_tokens": row.expired_tokens,
                "users_with_tokens": row.users_with_tokens,
            }

        except Exception as e:
            logger.error("Failed to get token statistics", error=str(e))
            return self._empty_statistics()