"""
Token Repository
Repository for refresh token operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text
from datetime import datetime, timezone, timedelta
import structlog

from .base import AuthBaseRepository
from app.models.tokens import RefreshToken
from shared.database.exceptions import DatabaseError

logger = structlog.get_logger()


class TokenRepository(AuthBaseRepository):
    """Repository for refresh token operations.

    Thin data-access layer over the ``refresh_tokens`` table. Simple CRUD goes
    through the ``AuthBaseRepository`` helpers (``create``, ``update``,
    ``get_by_field``); bulk or aggregate operations use raw SQL via
    ``self.session.execute``.
    """

    def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Tokens change frequently, shorter cache time
        super().__init__(model, session, cache_ttl)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_expired(expires_at: datetime, now: Optional[datetime] = None) -> bool:
        """Return True if ``expires_at`` is at or before ``now``.

        ``now`` defaults to the current UTC time. A naive ``expires_at`` is
        treated as UTC. Without this, comparing a naive value (e.g. one read
        from a ``TIMESTAMP WITHOUT TIME ZONE`` column) against an aware ``now``
        raises ``TypeError``. NOTE(review): assumes naive timestamps in this
        table are stored in UTC — confirm against the schema.

        ``<=`` matches the SQL used elsewhere in this class: active tokens are
        ``expires_at > now`` and expired tokens are ``expires_at <= now``.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return expires_at <= now

    @staticmethod
    def _empty_token_stats() -> Dict[str, Any]:
        """Return a fresh all-zero statistics dict (same keys as get_token_statistics)."""
        return {
            "total_tokens": 0,
            "active_tokens": 0,
            "revoked_tokens": 0,
            "expired_tokens": 0,
            "users_with_tokens": 0,
        }

    # ----------------------------------------------------------------- creation

    async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
        """Create a new refresh token from dictionary data.

        Delegates directly to the base repository's ``create``; any exception it
        raises propagates unchanged.
        """
        return await self.create(token_data)

    async def create_refresh_token(
        self,
        user_id: str,
        token: str,
        expires_at: datetime
    ) -> RefreshToken:
        """Create a new, non-revoked refresh token for ``user_id``.

        Args:
            user_id: Owner of the token.
            token: The token value to store.
            expires_at: Expiry timestamp.

        Returns:
            The created ``RefreshToken``.

        Raises:
            DatabaseError: If creation fails. The original exception is chained
                as ``__cause__``.
        """
        try:
            token_data = {
                "user_id": user_id,
                "token": token,
                "expires_at": expires_at,
                "is_revoked": False,
            }

            refresh_token = await self.create(token_data)

            logger.debug(
                "Refresh token created",
                user_id=user_id,
                token_id=refresh_token.id,
                expires_at=expires_at,
            )

            return refresh_token

        except Exception as e:
            logger.error("Failed to create refresh token", user_id=user_id, error=str(e))
            raise DatabaseError(f"Failed to create refresh token: {str(e)}") from e

    # ------------------------------------------------------------------ lookups

    async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
        """Get a refresh token by its token value, or None if not found.

        Raises:
            DatabaseError: If the lookup fails (original exception chained).
        """
        try:
            return await self.get_by_field("token", token)
        except Exception as e:
            logger.error("Failed to get token by value", error=str(e))
            raise DatabaseError(f"Failed to get token: {str(e)}") from e

    async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
        """Get all active (non-revoked, non-expired) tokens for a user.

        Results are ordered newest first (``created_at DESC``).

        NOTE(review): the returned objects are built with the ``RefreshToken``
        constructor from raw rows. They are not loaded through, or attached to,
        the session by this method. Attribute changes made to them will not be
        persisted unless the caller merges them into a session.

        Raises:
            DatabaseError: If the query fails (original exception chained).
        """
        try:
            now = datetime.now(timezone.utc)

            # Use raw query for complex filtering
            query = text("""
                SELECT * FROM refresh_tokens
                WHERE user_id = :user_id
                AND is_revoked = false
                AND expires_at > :now
                ORDER BY created_at DESC
            """)

            result = await self.session.execute(query, {
                "user_id": user_id,
                "now": now,
            })

            # Convert rows to RefreshToken objects
            return [
                RefreshToken(
                    id=row.id,
                    user_id=row.user_id,
                    token=row.token,
                    expires_at=row.expires_at,
                    is_revoked=row.is_revoked,
                    created_at=row.created_at,
                    revoked_at=row.revoked_at,
                )
                for row in result.fetchall()
            ]

        except Exception as e:
            logger.error("Failed to get active tokens for user", user_id=user_id, error=str(e))
            raise DatabaseError(f"Failed to get active tokens: {str(e)}") from e

    # --------------------------------------------------------------- revocation

    async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
        """Mark a single token as revoked and stamp ``revoked_at`` with UTC now.

        Returns whatever the base repository's ``update`` returns (presumably
        None when ``token_id`` does not exist — confirm in AuthBaseRepository).

        Raises:
            DatabaseError: If the update fails (original exception chained).
        """
        try:
            return await self.update(token_id, {
                "is_revoked": True,
                "revoked_at": datetime.now(timezone.utc),
            })
        except Exception as e:
            logger.error("Failed to revoke token", token_id=token_id, error=str(e))
            raise DatabaseError(f"Failed to revoke token: {str(e)}") from e

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """Revoke every not-yet-revoked token for a user.

        Performs a single bulk UPDATE. This method does not commit; presumably
        the caller's unit of work commits — confirm.

        Returns:
            The number of rows updated (``result.rowcount``).

        Raises:
            DatabaseError: If the update fails (original exception chained).
        """
        try:
            # Use bulk update for efficiency
            now = datetime.now(timezone.utc)

            query = text("""
                UPDATE refresh_tokens
                SET is_revoked = true, revoked_at = :revoked_at
                WHERE user_id = :user_id AND is_revoked = false
            """)

            result = await self.session.execute(query, {
                "user_id": user_id,
                "revoked_at": now,
            })

            revoked_count = result.rowcount

            logger.info("Revoked all user tokens", user_id=user_id, revoked_count=revoked_count)

            return revoked_count

        except Exception as e:
            logger.error("Failed to revoke all user tokens", user_id=user_id, error=str(e))
            raise DatabaseError(f"Failed to revoke user tokens: {str(e)}") from e

    # --------------------------------------------------------------- validation

    async def is_token_valid(self, token: str) -> bool:
        """Check whether a token exists, is not revoked and has not expired.

        Best-effort: any error (including a lookup failure) is logged and
        reported as ``False`` rather than raised.
        """
        try:
            refresh_token = await self.get_token_by_value(token)

            if not refresh_token or refresh_token.is_revoked:
                return False

            return not self._is_expired(refresh_token.expires_at)

        except Exception as e:
            logger.error("Failed to validate token", error=str(e))
            return False

    async def validate_refresh_token(self, token: str, user_id: str) -> bool:
        """Validate a refresh token for a specific user.

        A token is valid only if it:
        - exists,
        - belongs to ``user_id`` (compared as strings),
        - is not revoked,
        - has not expired.

        Best-effort: any error is logged and reported as ``False``.
        """
        try:
            refresh_token = await self.get_token_by_value(token)

            if not refresh_token:
                # NOTE(review): this logs the first 10 characters of the secret
                # token value; consider whether that is acceptable.
                logger.debug("Refresh token not found", token_prefix=token[:10] + "...")
                return False

            # Convert both to strings for comparison to handle UUID vs string mismatch
            token_user_id = str(refresh_token.user_id)
            expected_user_id = str(user_id)

            if token_user_id != expected_user_id:
                logger.warning(
                    "Refresh token user_id mismatch",
                    expected_user_id=expected_user_id,
                    actual_user_id=token_user_id,
                )
                return False

            if refresh_token.is_revoked:
                logger.debug("Refresh token is revoked", user_id=user_id)
                return False

            if self._is_expired(refresh_token.expires_at):
                logger.debug("Refresh token is expired", user_id=user_id)
                return False

            logger.debug("Refresh token is valid", user_id=user_id)
            return True

        except Exception as e:
            logger.error("Failed to validate refresh token", user_id=user_id, error=str(e))
            return False

    # ------------------------------------------------------------------ cleanup

    async def cleanup_expired_tokens(self) -> int:
        """Delete every token whose ``expires_at`` is strictly before UTC now.

        This method does not commit; presumably the caller commits — confirm.

        Returns:
            The number of rows deleted (``result.rowcount``).

        Raises:
            DatabaseError: If the delete fails (original exception chained).
        """
        try:
            now = datetime.now(timezone.utc)

            # Delete expired tokens
            query = text("""
                DELETE FROM refresh_tokens
                WHERE expires_at < :now
            """)

            result = await self.session.execute(query, {"now": now})
            deleted_count = result.rowcount

            logger.info("Cleaned up expired tokens", deleted_count=deleted_count)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup expired tokens", error=str(e))
            raise DatabaseError(f"Token cleanup failed: {str(e)}") from e

    async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
        """Delete revoked tokens whose ``revoked_at`` is more than ``days_old`` days ago.

        Rows with a NULL ``revoked_at`` never match the ``<`` comparison and are
        kept. This method does not commit; presumably the caller commits —
        confirm.

        Returns:
            The number of rows deleted (``result.rowcount``).

        Raises:
            DatabaseError: If the delete fails (original exception chained).
        """
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)

            query = text("""
                DELETE FROM refresh_tokens
                WHERE is_revoked = true
                AND revoked_at < :cutoff_date
            """)

            result = await self.session.execute(query, {
                "cutoff_date": cutoff_date,
            })

            deleted_count = result.rowcount

            logger.info(
                "Cleaned up old revoked tokens",
                deleted_count=deleted_count,
                days_old=days_old,
            )

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old revoked tokens", days_old=days_old, error=str(e))
            raise DatabaseError(f"Revoked token cleanup failed: {str(e)}") from e

    # --------------------------------------------------------------- statistics

    async def get_token_statistics(self) -> Dict[str, Any]:
        """Get aggregate token counts.

        Returns:
            A dict with these keys:
            - ``total_tokens``
            - ``active_tokens``: not revoked and ``expires_at > now``
            - ``revoked_tokens``
            - ``expired_tokens``: ``expires_at <= now``, revoked or not
            - ``users_with_tokens``

        Best-effort: on any error the failure is logged and an all-zero dict
        is returned instead of raising.
        """
        try:
            now = datetime.now(timezone.utc)

            # Get counts with raw queries
            stats_query = text("""
                SELECT
                    COUNT(*) as total_tokens,
                    COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
                    COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
                    COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
                    COUNT(DISTINCT user_id) as users_with_tokens
                FROM refresh_tokens
            """)

            result = await self.session.execute(stats_query, {"now": now})
            row = result.fetchone()

            if not row:
                return self._empty_token_stats()

            return {
                "total_tokens": row.total_tokens,
                "active_tokens": row.active_tokens,
                "revoked_tokens": row.revoked_tokens,
                "expired_tokens": row.expired_tokens,
                "users_with_tokens": row.users_with_tokens,
            }

        except Exception as e:
            logger.error("Failed to get token statistics", error=str(e))
            return self._empty_token_stats()