Files
bakery-ia/services/auth/app/repositories/token_repository.py
2025-09-17 16:06:30 +02:00

305 lines
11 KiB
Python

"""
Token Repository
Repository for refresh token operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from datetime import datetime, timezone, timedelta
import structlog
from .base import AuthBaseRepository
from app.models.users import RefreshToken
from shared.database.exceptions import DatabaseError
# Module-level structlog logger used by the repository methods in this file.
logger = structlog.get_logger()
class TokenRepository(AuthBaseRepository):
    """Repository for refresh token operations.

    Combines the generic helpers this class calls on itself (``create``,
    ``update``, ``get_by_field`` -- expected to come from
    ``AuthBaseRepository``) with refresh-token specific queries executed on
    ``self.session``.

    None of the methods commit; transaction handling is left to the caller.
    """

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Initialise the repository.

        Args:
            model: Model class forwarded to the base repository.
            session: Async SQLAlchemy session used for all queries.
            cache_ttl: Cache lifetime forwarded to the base repository.
        """
        # Tokens change frequently, shorter cache time
        super().__init__(model, session, cache_ttl)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_expired(expires_at: datetime, now: Optional[datetime] = None) -> bool:
        """Return True when ``expires_at`` lies strictly before ``now``.

        Comparing a naive datetime with an aware one raises ``TypeError``.
        In the validation methods that error would be swallowed by the broad
        ``except`` and every token would be reported as invalid, so naive
        values are normalised first.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        if expires_at.tzinfo is None:
            # NOTE(review): assumes naive timestamps are stored in UTC, which
            # matches how this repository generates its own timestamps --
            # confirm against the column definition.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at < now

    @staticmethod
    def _empty_statistics() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics payload."""
        return {
            "total_tokens": 0,
            "active_tokens": 0,
            "revoked_tokens": 0,
            "expired_tokens": 0,
            "users_with_tokens": 0
        }

    # ------------------------------------------------------------------
    # Creation
    # ------------------------------------------------------------------

    async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
        """Create a new refresh token from dictionary data.

        Thin pass-through to ``self.create``; errors propagate unchanged.
        """
        return await self.create(token_data)

    async def create_refresh_token(
        self,
        user_id: str,
        token: str,
        expires_at: datetime
    ) -> RefreshToken:
        """Create a new, non-revoked refresh token.

        Args:
            user_id: Owner of the token.
            token: Token value to store.
            expires_at: Moment after which the token is no longer valid.

        Returns:
            The object returned by ``self.create``.

        Raises:
            DatabaseError: If creation fails for any reason.
        """
        try:
            token_data = {
                "user_id": user_id,
                "token": token,
                "expires_at": expires_at,
                "is_revoked": False
            }
            refresh_token = await self.create(token_data)
            logger.debug("Refresh token created",
                         user_id=user_id,
                         token_id=refresh_token.id,
                         expires_at=expires_at)
            return refresh_token
        except Exception as e:
            logger.error("Failed to create refresh token",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create refresh token: {str(e)}") from e

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
        """Get a refresh token by its token value.

        Delegates to ``self.get_by_field("token", token)``.

        Raises:
            DatabaseError: If the lookup fails.
        """
        try:
            return await self.get_by_field("token", token)
        except Exception as e:
            logger.error("Failed to get token by value", error=str(e))
            raise DatabaseError(f"Failed to get token: {str(e)}") from e

    async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
        """Get all active (non-revoked, non-expired) tokens for a user.

        Results are ordered newest first by ``created_at``.

        The query goes through the ORM so the returned objects are
        session-attached entities carrying every mapped column, rather than
        detached copies rebuilt from a hand-maintained column list.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            now = datetime.now(timezone.utc)
            query = (
                select(RefreshToken)
                .where(
                    and_(
                        RefreshToken.user_id == user_id,
                        RefreshToken.is_revoked.is_(False),
                        RefreshToken.expires_at > now,
                    )
                )
                .order_by(RefreshToken.created_at.desc())
            )
            result = await self.session.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Failed to get active tokens for user",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active tokens: {str(e)}") from e

    # ------------------------------------------------------------------
    # Revocation
    # ------------------------------------------------------------------

    async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
        """Revoke a refresh token.

        Sets ``is_revoked`` and stamps ``revoked_at`` with the current UTC
        time via ``self.update``.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            return await self.update(token_id, {
                "is_revoked": True,
                "revoked_at": datetime.now(timezone.utc)
            })
        except Exception as e:
            logger.error("Failed to revoke token",
                         token_id=token_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke token: {str(e)}") from e

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """Revoke all not-yet-revoked tokens for a user.

        Returns:
            The row count reported by the UPDATE statement.

        Raises:
            DatabaseError: If the update fails.
        """
        try:
            # Use bulk update for efficiency
            now = datetime.now(timezone.utc)
            query = text("""
                UPDATE refresh_tokens
                SET is_revoked = true, revoked_at = :revoked_at
                WHERE user_id = :user_id AND is_revoked = false
            """)
            result = await self.session.execute(query, {
                "user_id": user_id,
                "revoked_at": now
            })
            revoked_count = result.rowcount
            logger.info("Revoked all user tokens",
                        user_id=user_id,
                        revoked_count=revoked_count)
            return revoked_count
        except Exception as e:
            logger.error("Failed to revoke all user tokens",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to revoke user tokens: {str(e)}") from e

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    async def is_token_valid(self, token: str) -> bool:
        """Check if a token is valid (exists, not revoked, not expired).

        Never raises: any failure is logged and reported as ``False``.
        """
        try:
            refresh_token = await self.get_token_by_value(token)
            if not refresh_token:
                return False
            if refresh_token.is_revoked:
                return False
            if self._is_expired(refresh_token.expires_at):
                return False
            return True
        except Exception as e:
            logger.error("Failed to validate token", error=str(e))
            return False

    async def validate_refresh_token(self, token: str, user_id: str) -> bool:
        """Validate a refresh token for a specific user.

        The token must exist, belong to ``user_id``, not be revoked and not
        be expired. Never raises: any failure is logged and reported as
        ``False``.
        """
        try:
            refresh_token = await self.get_token_by_value(token)
            if not refresh_token:
                logger.debug("Refresh token not found", token_prefix=token[:10] + "...")
                return False

            # Convert both to strings for comparison to handle UUID vs string mismatch
            token_user_id = str(refresh_token.user_id)
            expected_user_id = str(user_id)
            if token_user_id != expected_user_id:
                logger.warning("Refresh token user_id mismatch",
                               expected_user_id=expected_user_id,
                               actual_user_id=token_user_id)
                return False

            if refresh_token.is_revoked:
                logger.debug("Refresh token is revoked", user_id=user_id)
                return False

            if self._is_expired(refresh_token.expires_at):
                logger.debug("Refresh token is expired", user_id=user_id)
                return False

            logger.debug("Refresh token is valid", user_id=user_id)
            return True
        except Exception as e:
            logger.error("Failed to validate refresh token",
                         user_id=user_id,
                         error=str(e))
            return False

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    async def cleanup_expired_tokens(self) -> int:
        """Delete every refresh token whose ``expires_at`` is in the past.

        Returns:
            The row count reported by the DELETE statement.

        Raises:
            DatabaseError: If the delete fails.
        """
        try:
            now = datetime.now(timezone.utc)
            # Delete expired tokens
            query = text("""
                DELETE FROM refresh_tokens
                WHERE expires_at < :now
            """)
            result = await self.session.execute(query, {"now": now})
            deleted_count = result.rowcount
            logger.info("Cleaned up expired tokens",
                        deleted_count=deleted_count)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup expired tokens", error=str(e))
            raise DatabaseError(f"Token cleanup failed: {str(e)}") from e

    async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
        """Delete revoked tokens whose ``revoked_at`` is older than the cutoff.

        Args:
            days_old: Age threshold in days, measured back from the current
                UTC time.

        Returns:
            The row count reported by the DELETE statement.

        Raises:
            DatabaseError: If the delete fails.
        """
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
            query = text("""
                DELETE FROM refresh_tokens
                WHERE is_revoked = true
                AND revoked_at < :cutoff_date
            """)
            result = await self.session.execute(query, {
                "cutoff_date": cutoff_date
            })
            deleted_count = result.rowcount
            logger.info("Cleaned up old revoked tokens",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old revoked tokens",
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Revoked token cleanup failed: {str(e)}") from e

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    async def get_token_statistics(self) -> Dict[str, Any]:
        """Get aggregate token statistics.

        Returns a dict with the keys ``total_tokens``, ``active_tokens``,
        ``revoked_tokens``, ``expired_tokens`` and ``users_with_tokens``.

        Best effort: if the query fails the error is logged and a payload of
        zeros is returned instead of raising.
        """
        try:
            now = datetime.now(timezone.utc)
            # Get counts with raw queries
            stats_query = text("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
                    COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
                    COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
                    COUNT(DISTINCT user_id) as users_with_tokens
                FROM refresh_tokens
            """)
            result = await self.session.execute(stats_query, {"now": now})
            row = result.fetchone()
            if row:
                return {
                    "total_tokens": row.total_tokens,
                    "active_tokens": row.active_tokens,
                    "revoked_tokens": row.revoked_tokens,
                    "expired_tokens": row.expired_tokens,
                    "users_with_tokens": row.users_with_tokens
                }
            return self._empty_statistics()
        except Exception as e:
            logger.error("Failed to get token statistics", error=str(e))
            return self._empty_statistics()