Initial commit - production deployment
This commit is contained in:
16
services/auth/app/repositories/__init__.py
Normal file
16
services/auth/app/repositories/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Auth Service Repositories
|
||||
Repository implementations for authentication service
|
||||
"""
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from .user_repository import UserRepository
|
||||
from .token_repository import TokenRepository
|
||||
from .onboarding_repository import OnboardingRepository
|
||||
|
||||
__all__ = [
|
||||
"AuthBaseRepository",
|
||||
"UserRepository",
|
||||
"TokenRepository",
|
||||
"OnboardingRepository"
|
||||
]
|
||||
101
services/auth/app/repositories/base.py
Normal file
101
services/auth/app/repositories/base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Base Repository for Auth Service
|
||||
Service-specific repository base class with auth service utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthBaseRepository(BaseRepository):
|
||||
"""Base repository for auth service with common auth operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Auth data benefits from longer caching (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional:
|
||||
"""Get record by email (if model has email field)"""
|
||||
if hasattr(self.model, 'email'):
|
||||
return await self.get_by_field("email", email)
|
||||
return None
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional:
|
||||
"""Get record by username (if model has username field)"""
|
||||
if hasattr(self.model, 'username'):
|
||||
return await self.get_by_field("username", username)
|
||||
return None
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_expired_records(self, field_name: str = "expires_at") -> int:
|
||||
"""Clean up expired records (for tokens, sessions, etc.)"""
|
||||
try:
|
||||
if not hasattr(self.model, field_name):
|
||||
logger.warning(f"Model {self.model.__name__} has no {field_name} field for cleanup")
|
||||
return 0
|
||||
|
||||
# This would need custom implementation with raw SQL for date comparison
|
||||
# For now, return 0 to indicate no cleanup performed
|
||||
logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
def _validate_auth_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate authentication-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate email format if present
|
||||
if "email" in data and data["email"]:
|
||||
email = data["email"]
|
||||
if "@" not in email or "." not in email.split("@")[-1]:
|
||||
errors.append("Invalid email format")
|
||||
|
||||
# Validate password strength if present
|
||||
if "password" in data and data["password"]:
|
||||
password = data["password"]
|
||||
if len(password) < 8:
|
||||
errors.append("Password must be at least 8 characters long")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
110
services/auth/app/repositories/deletion_job_repository.py
Normal file
110
services/auth/app/repositories/deletion_job_repository.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Deletion Job Repository
|
||||
Database operations for deletion job persistence
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, and_, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from app.models.deletion_job import DeletionJob
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DeletionJobRepository:
|
||||
"""Repository for deletion job database operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create(self, deletion_job: DeletionJob) -> DeletionJob:
|
||||
"""Create a new deletion job record"""
|
||||
try:
|
||||
self.session.add(deletion_job)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(deletion_job)
|
||||
return deletion_job
|
||||
except Exception as e:
|
||||
logger.error("Failed to create deletion job", error=str(e))
|
||||
raise
|
||||
|
||||
async def get_by_job_id(self, job_id: str) -> Optional[DeletionJob]:
|
||||
"""Get deletion job by job_id"""
|
||||
try:
|
||||
query = select(DeletionJob).where(DeletionJob.job_id == job_id)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get deletion job", error=str(e), job_id=job_id)
|
||||
raise
|
||||
|
||||
async def get_by_id(self, id: UUID) -> Optional[DeletionJob]:
|
||||
"""Get deletion job by database ID"""
|
||||
try:
|
||||
return await self.session.get(DeletionJob, id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get deletion job by ID", error=str(e), id=str(id))
|
||||
raise
|
||||
|
||||
async def list_by_tenant(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[DeletionJob]:
|
||||
"""List deletion jobs for a tenant"""
|
||||
try:
|
||||
query = select(DeletionJob).where(DeletionJob.tenant_id == tenant_id)
|
||||
|
||||
if status:
|
||||
query = query.where(DeletionJob.status == status)
|
||||
|
||||
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Failed to list deletion jobs", error=str(e), tenant_id=str(tenant_id))
|
||||
raise
|
||||
|
||||
async def list_all(
|
||||
self,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[DeletionJob]:
|
||||
"""List all deletion jobs with optional status filter"""
|
||||
try:
|
||||
query = select(DeletionJob)
|
||||
|
||||
if status:
|
||||
query = query.where(DeletionJob.status == status)
|
||||
|
||||
query = query.order_by(desc(DeletionJob.started_at)).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Failed to list all deletion jobs", error=str(e))
|
||||
raise
|
||||
|
||||
async def update(self, deletion_job: DeletionJob) -> DeletionJob:
|
||||
"""Update a deletion job record"""
|
||||
try:
|
||||
await self.session.flush()
|
||||
await self.session.refresh(deletion_job)
|
||||
return deletion_job
|
||||
except Exception as e:
|
||||
logger.error("Failed to update deletion job", error=str(e))
|
||||
raise
|
||||
|
||||
async def delete(self, deletion_job: DeletionJob) -> None:
|
||||
"""Delete a deletion job record"""
|
||||
try:
|
||||
await self.session.delete(deletion_job)
|
||||
await self.session.flush()
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete deletion job", error=str(e))
|
||||
raise
|
||||
313
services/auth/app/repositories/onboarding_repository.py
Normal file
313
services/auth/app/repositories/onboarding_repository.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# services/auth/app/repositories/onboarding_repository.py
|
||||
"""
|
||||
Onboarding Repository for database operations
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete, and_
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class OnboardingRepository:
|
||||
"""Repository for onboarding progress operations"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
|
||||
"""Get all onboarding steps for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
.order_by(UserOnboardingProgress.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user progress steps for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
|
||||
"""Get a specific step for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_step(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
completed: bool,
|
||||
step_data: Dict[str, Any] = None,
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Insert or update a user's onboarding step
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
completed: Whether the step is completed
|
||||
step_data: Additional data for the step
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
completed_at = datetime.now(timezone.utc) if completed else None
|
||||
step_data = step_data or {}
|
||||
|
||||
# Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
|
||||
stmt = insert(UserOnboardingProgress).values(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=completed,
|
||||
completed_at=completed_at,
|
||||
step_data=step_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id', 'step_name'],
|
||||
set_=dict(
|
||||
completed=stmt.excluded.completed,
|
||||
completed_at=stmt.excluded.completed_at,
|
||||
step_data=stmt.excluded.step_data,
|
||||
updated_at=stmt.excluded.updated_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingProgress)
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
# Only commit if auto_commit is True (not within a UnitOfWork)
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
# Flush to ensure the statement is executed
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
|
||||
"""Get user's onboarding summary"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
current_step: str,
|
||||
next_step: Optional[str],
|
||||
completion_percentage: float,
|
||||
fully_completed: bool,
|
||||
steps_completed_count: str
|
||||
) -> UserOnboardingSummary:
|
||||
"""Insert or update user's onboarding summary"""
|
||||
try:
|
||||
# Use PostgreSQL UPSERT
|
||||
stmt = insert(UserOnboardingSummary).values(
|
||||
user_id=user_id,
|
||||
current_step=current_step,
|
||||
next_step=next_step,
|
||||
completion_percentage=str(completion_percentage),
|
||||
fully_completed=fully_completed,
|
||||
steps_completed_count=steps_completed_count,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
last_activity_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id'],
|
||||
set_=dict(
|
||||
current_step=stmt.excluded.current_step,
|
||||
next_step=stmt.excluded.next_step,
|
||||
completion_percentage=stmt.excluded.completion_percentage,
|
||||
fully_completed=stmt.excluded.fully_completed,
|
||||
steps_completed_count=stmt.excluded.steps_completed_count,
|
||||
updated_at=stmt.excluded.updated_at,
|
||||
last_activity_at=stmt.excluded.last_activity_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingSummary)
|
||||
result = await self.db.execute(stmt)
|
||||
await self.db.commit()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting summary for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def delete_user_progress(self, user_id: str) -> bool:
|
||||
"""Delete all onboarding progress for a user"""
|
||||
try:
|
||||
# Delete steps
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
)
|
||||
|
||||
# Delete summary
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
|
||||
await self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting progress for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
return False
|
||||
|
||||
async def save_step_data(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
step_data: Dict[str, Any],
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Save data for a specific step without marking it as completed
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
step_data: Data to save
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
# Get existing step or create new one
|
||||
existing_step = await self.get_user_step(user_id, step_name)
|
||||
|
||||
if existing_step:
|
||||
# Update existing step data (merge with existing data)
|
||||
merged_data = {**(existing_step.step_data or {}), **step_data}
|
||||
|
||||
stmt = update(UserOnboardingProgress).where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
).values(
|
||||
step_data=merged_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
).returning(UserOnboardingProgress)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
else:
|
||||
# Create new step with data but not completed
|
||||
return await self.upsert_user_step(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=False,
|
||||
step_data=step_data,
|
||||
auto_commit=auto_commit
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get data for a specific step"""
|
||||
try:
|
||||
step = await self.get_user_step(user_id, step_name)
|
||||
return step.step_data if step else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get subscription parameters saved during onboarding for tenant creation"""
|
||||
try:
|
||||
step_data = await self.get_step_data(user_id, "user_registered")
|
||||
if step_data:
|
||||
# Extract subscription-related parameters
|
||||
subscription_params = {
|
||||
"subscription_plan": step_data.get("subscription_plan", "starter"),
|
||||
"billing_cycle": step_data.get("billing_cycle", "monthly"),
|
||||
"coupon_code": step_data.get("coupon_code"),
|
||||
"payment_method_id": step_data.get("payment_method_id"),
|
||||
"payment_customer_id": step_data.get("payment_customer_id"),
|
||||
"saved_at": step_data.get("saved_at")
|
||||
}
|
||||
return subscription_params
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_completion_stats(self) -> Dict[str, Any]:
|
||||
"""Get completion statistics across all users"""
|
||||
try:
|
||||
# Get total users with onboarding data
|
||||
total_result = await self.db.execute(
|
||||
select(UserOnboardingSummary).count()
|
||||
)
|
||||
total_users = total_result.scalar()
|
||||
|
||||
# Get completed users
|
||||
completed_result = await self.db.execute(
|
||||
select(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.fully_completed == True)
|
||||
.count()
|
||||
)
|
||||
completed_users = completed_result.scalar()
|
||||
|
||||
return {
|
||||
"total_users_in_onboarding": total_users,
|
||||
"fully_completed_users": completed_users,
|
||||
"completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting completion stats: {e}")
|
||||
return {
|
||||
"total_users_in_onboarding": 0,
|
||||
"fully_completed_users": 0,
|
||||
"completion_rate": 0
|
||||
}
|
||||
124
services/auth/app/repositories/password_reset_repository.py
Normal file
124
services/auth/app/repositories/password_reset_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# services/auth/app/repositories/password_reset_repository.py
|
||||
"""
|
||||
Password reset token repository
|
||||
Repository for password reset token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.password_reset_tokens import PasswordResetToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PasswordResetTokenRepository(AuthBaseRepository):
|
||||
"""Repository for password reset token operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(PasswordResetToken, session)
|
||||
|
||||
async def create_token(self, user_id: str, token: str, expires_at: datetime) -> PasswordResetToken:
|
||||
"""Create a new password reset token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_used": False
|
||||
}
|
||||
|
||||
reset_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Password reset token created",
|
||||
user_id=user_id,
|
||||
token_id=reset_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return reset_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create password reset token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create password reset token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[PasswordResetToken]:
|
||||
"""Get password reset token by token value"""
|
||||
try:
|
||||
stmt = select(PasswordResetToken).where(
|
||||
and_(
|
||||
PasswordResetToken.token == token,
|
||||
PasswordResetToken.is_used == False,
|
||||
PasswordResetToken.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get password reset token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get password reset token: {str(e)}")
|
||||
|
||||
async def mark_token_as_used(self, token_id: str) -> Optional[PasswordResetToken]:
|
||||
"""Mark a password reset token as used"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_used": True,
|
||||
"used_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark password reset token as used",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to mark token as used: {str(e)}")
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired password reset tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM password_reset_tokens
|
||||
WHERE expires_at < :now OR is_used = true
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired password reset tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired password reset tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_valid_token_for_user(self, user_id: str) -> Optional[PasswordResetToken]:
|
||||
"""Get a valid (unused, not expired) password reset token for a user"""
|
||||
try:
|
||||
stmt = select(PasswordResetToken).where(
|
||||
and_(
|
||||
PasswordResetToken.user_id == user_id,
|
||||
PasswordResetToken.is_used == False,
|
||||
PasswordResetToken.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
).order_by(PasswordResetToken.created_at.desc())
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get valid token for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get valid token for user: {str(e)}")
|
||||
305
services/auth/app/repositories/token_repository.py
Normal file
305
services/auth/app/repositories/token_repository.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Token Repository
|
||||
Repository for refresh token operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.tokens import RefreshToken
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TokenRepository(AuthBaseRepository):
|
||||
"""Repository for refresh token operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Tokens change frequently, shorter cache time
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_token(self, token_data: Dict[str, Any]) -> RefreshToken:
|
||||
"""Create a new refresh token from dictionary data"""
|
||||
return await self.create(token_data)
|
||||
|
||||
async def create_refresh_token(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
expires_at: datetime
|
||||
) -> RefreshToken:
|
||||
"""Create a new refresh token"""
|
||||
try:
|
||||
token_data = {
|
||||
"user_id": user_id,
|
||||
"token": token,
|
||||
"expires_at": expires_at,
|
||||
"is_revoked": False
|
||||
}
|
||||
|
||||
refresh_token = await self.create(token_data)
|
||||
|
||||
logger.debug("Refresh token created",
|
||||
user_id=user_id,
|
||||
token_id=refresh_token.id,
|
||||
expires_at=expires_at)
|
||||
|
||||
return refresh_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create refresh token: {str(e)}")
|
||||
|
||||
async def get_token_by_value(self, token: str) -> Optional[RefreshToken]:
|
||||
"""Get refresh token by token value"""
|
||||
try:
|
||||
return await self.get_by_field("token", token)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token by value", error=str(e))
|
||||
raise DatabaseError(f"Failed to get token: {str(e)}")
|
||||
|
||||
async def get_active_tokens_for_user(self, user_id: str) -> List[RefreshToken]:
|
||||
"""Get all active (non-revoked, non-expired) tokens for a user"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use raw query for complex filtering
|
||||
query = text("""
|
||||
SELECT * FROM refresh_tokens
|
||||
WHERE user_id = :user_id
|
||||
AND is_revoked = false
|
||||
AND expires_at > :now
|
||||
ORDER BY created_at DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"now": now
|
||||
})
|
||||
|
||||
# Convert rows to RefreshToken objects
|
||||
tokens = []
|
||||
for row in result.fetchall():
|
||||
token = RefreshToken(
|
||||
id=row.id,
|
||||
user_id=row.user_id,
|
||||
token=row.token,
|
||||
expires_at=row.expires_at,
|
||||
is_revoked=row.is_revoked,
|
||||
created_at=row.created_at,
|
||||
revoked_at=row.revoked_at
|
||||
)
|
||||
tokens.append(token)
|
||||
|
||||
return tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active tokens for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get active tokens: {str(e)}")
|
||||
|
||||
async def revoke_token(self, token_id: str) -> Optional[RefreshToken]:
|
||||
"""Revoke a refresh token"""
|
||||
try:
|
||||
return await self.update(token_id, {
|
||||
"is_revoked": True,
|
||||
"revoked_at": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke token",
|
||||
token_id=token_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke token: {str(e)}")
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> int:
|
||||
"""Revoke all tokens for a user"""
|
||||
try:
|
||||
# Use bulk update for efficiency
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
query = text("""
|
||||
UPDATE refresh_tokens
|
||||
SET is_revoked = true, revoked_at = :revoked_at
|
||||
WHERE user_id = :user_id AND is_revoked = false
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"user_id": user_id,
|
||||
"revoked_at": now
|
||||
})
|
||||
|
||||
revoked_count = result.rowcount
|
||||
|
||||
logger.info("Revoked all user tokens",
|
||||
user_id=user_id,
|
||||
revoked_count=revoked_count)
|
||||
|
||||
return revoked_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to revoke all user tokens",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to revoke user tokens: {str(e)}")
|
||||
|
||||
async def is_token_valid(self, token: str) -> bool:
|
||||
"""Check if a token is valid (exists, not revoked, not expired)"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate token", error=str(e))
|
||||
return False
|
||||
|
||||
async def validate_refresh_token(self, token: str, user_id: str) -> bool:
|
||||
"""Validate refresh token for a specific user"""
|
||||
try:
|
||||
refresh_token = await self.get_token_by_value(token)
|
||||
|
||||
if not refresh_token:
|
||||
logger.debug("Refresh token not found", token_prefix=token[:10] + "...")
|
||||
return False
|
||||
|
||||
# Convert both to strings for comparison to handle UUID vs string mismatch
|
||||
token_user_id = str(refresh_token.user_id)
|
||||
expected_user_id = str(user_id)
|
||||
|
||||
if token_user_id != expected_user_id:
|
||||
logger.warning("Refresh token user_id mismatch",
|
||||
expected_user_id=expected_user_id,
|
||||
actual_user_id=token_user_id)
|
||||
return False
|
||||
|
||||
if refresh_token.is_revoked:
|
||||
logger.debug("Refresh token is revoked", user_id=user_id)
|
||||
return False
|
||||
|
||||
if refresh_token.expires_at < datetime.now(timezone.utc):
|
||||
logger.debug("Refresh token is expired", user_id=user_id)
|
||||
return False
|
||||
|
||||
logger.debug("Refresh token is valid", user_id=user_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate refresh token",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_expired_tokens(self) -> int:
|
||||
"""Clean up expired refresh tokens"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Delete expired tokens
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE expires_at < :now
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"now": now})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired tokens",
|
||||
deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired tokens", error=str(e))
|
||||
raise DatabaseError(f"Token cleanup failed: {str(e)}")
|
||||
|
||||
async def cleanup_old_revoked_tokens(self, days_old: int = 30) -> int:
|
||||
"""Clean up old revoked tokens"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
|
||||
|
||||
query = text("""
|
||||
DELETE FROM refresh_tokens
|
||||
WHERE is_revoked = true
|
||||
AND revoked_at < :cutoff_date
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old revoked tokens",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old revoked tokens",
|
||||
days_old=days_old,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Revoked token cleanup failed: {str(e)}")
|
||||
|
||||
async def get_token_statistics(self) -> Dict[str, Any]:
|
||||
"""Get token statistics"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Get counts with raw queries
|
||||
stats_query = text("""
|
||||
SELECT
|
||||
COUNT(*) as total_tokens,
|
||||
COUNT(CASE WHEN is_revoked = false AND expires_at > :now THEN 1 END) as active_tokens,
|
||||
COUNT(CASE WHEN is_revoked = true THEN 1 END) as revoked_tokens,
|
||||
COUNT(CASE WHEN expires_at <= :now THEN 1 END) as expired_tokens,
|
||||
COUNT(DISTINCT user_id) as users_with_tokens
|
||||
FROM refresh_tokens
|
||||
""")
|
||||
|
||||
result = await self.session.execute(stats_query, {"now": now})
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"total_tokens": row.total_tokens,
|
||||
"active_tokens": row.active_tokens,
|
||||
"revoked_tokens": row.revoked_tokens,
|
||||
"expired_tokens": row.expired_tokens,
|
||||
"users_with_tokens": row.users_with_tokens
|
||||
}
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get token statistics", error=str(e))
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"active_tokens": 0,
|
||||
"revoked_tokens": 0,
|
||||
"expired_tokens": 0,
|
||||
"users_with_tokens": 0
|
||||
}
|
||||
277
services/auth/app/repositories/user_repository.py
Normal file
277
services/auth/app/repositories/user_repository.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
User Repository
|
||||
Repository for user operations with authentication-specific queries
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc, text
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import AuthBaseRepository
|
||||
from app.models.users import User
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class UserRepository(AuthBaseRepository):
|
||||
"""Repository for user operations"""
|
||||
|
||||
def __init__(self, model, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def create_user(self, user_data: Dict[str, Any]) -> User:
|
||||
"""Create a new user with validation"""
|
||||
try:
|
||||
# Validate user data
|
||||
validation_result = self._validate_auth_data(
|
||||
user_data,
|
||||
["email", "hashed_password", "full_name", "role"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid user data: {validation_result['errors']}")
|
||||
|
||||
# Check if user already exists
|
||||
existing_user = await self.get_by_email(user_data["email"])
|
||||
if existing_user:
|
||||
raise DuplicateRecordError(f"User with email {user_data['email']} already exists")
|
||||
|
||||
# Create user
|
||||
user = await self.create(user_data)
|
||||
|
||||
logger.info("User created successfully",
|
||||
user_id=user.id,
|
||||
email=user.email,
|
||||
role=user.role)
|
||||
|
||||
return user
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create user",
|
||||
email=user_data.get("email"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create user: {str(e)}")
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
"""Get user by email address"""
|
||||
return await self.get_by_email(email)
|
||||
|
||||
async def get_active_users(self, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""Get all active users"""
|
||||
return await self.get_active_records(skip=skip, limit=limit)
|
||||
|
||||
async def authenticate_user(self, email: str, password: str) -> Optional[User]:
|
||||
"""Authenticate user with email and plain password"""
|
||||
try:
|
||||
user = await self.get_by_email(email)
|
||||
|
||||
if not user:
|
||||
logger.debug("User not found for authentication", email=email)
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
logger.debug("User account is inactive", email=email)
|
||||
return None
|
||||
|
||||
# Verify password using security manager
|
||||
from app.core.security import SecurityManager
|
||||
if SecurityManager.verify_password(password, user.hashed_password):
|
||||
# Update last login
|
||||
await self.update_last_login(user.id)
|
||||
logger.info("User authenticated successfully",
|
||||
user_id=user.id,
|
||||
email=email)
|
||||
return user
|
||||
|
||||
logger.debug("Invalid password for user", email=email)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Authentication failed",
|
||||
email=email,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Authentication failed: {str(e)}")
|
||||
|
||||
async def update_last_login(self, user_id: str) -> Optional[User]:
|
||||
"""Update user's last login timestamp"""
|
||||
try:
|
||||
return await self.update(user_id, {
|
||||
"last_login": datetime.now(timezone.utc)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update last login",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
# Don't raise here - last login update is not critical
|
||||
return None
|
||||
|
||||
async def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> Optional[User]:
|
||||
"""Update user profile information"""
|
||||
try:
|
||||
# Remove sensitive fields that shouldn't be updated via profile
|
||||
profile_data.pop("id", None)
|
||||
profile_data.pop("hashed_password", None)
|
||||
profile_data.pop("created_at", None)
|
||||
profile_data.pop("is_active", None)
|
||||
|
||||
# Validate email if being updated
|
||||
if "email" in profile_data:
|
||||
validation_result = self._validate_auth_data(
|
||||
profile_data,
|
||||
["email"]
|
||||
)
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid profile data: {validation_result['errors']}")
|
||||
|
||||
# Check for email conflicts
|
||||
existing_user = await self.get_by_email(profile_data["email"])
|
||||
if existing_user and str(existing_user.id) != str(user_id):
|
||||
raise DuplicateRecordError(f"Email {profile_data['email']} is already in use")
|
||||
|
||||
updated_user = await self.update(user_id, profile_data)
|
||||
|
||||
if updated_user:
|
||||
logger.info("User profile updated",
|
||||
user_id=user_id,
|
||||
updated_fields=list(profile_data.keys()))
|
||||
|
||||
return updated_user
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update user profile",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update profile: {str(e)}")
|
||||
|
||||
async def change_password(self, user_id: str, new_password_hash: str) -> bool:
|
||||
"""Change user password"""
|
||||
try:
|
||||
updated_user = await self.update(user_id, {
|
||||
"hashed_password": new_password_hash
|
||||
})
|
||||
|
||||
if updated_user:
|
||||
logger.info("Password changed successfully", user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to change password",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to change password: {str(e)}")
|
||||
|
||||
async def verify_user_email(self, user_id: str) -> Optional[User]:
|
||||
"""Mark user email as verified"""
|
||||
try:
|
||||
return await self.update(user_id, {
|
||||
"is_verified": True
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify user email",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to verify email: {str(e)}")
|
||||
|
||||
async def deactivate_user(self, user_id: str) -> Optional[User]:
|
||||
"""Deactivate user account"""
|
||||
return await self.deactivate_record(user_id)
|
||||
|
||||
async def activate_user(self, user_id: str) -> Optional[User]:
|
||||
"""Activate user account"""
|
||||
return await self.activate_record(user_id)
|
||||
|
||||
async def get_users_by_role(self, role: str, skip: int = 0, limit: int = 100) -> List[User]:
|
||||
"""Get users by role"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"role": role, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get users by role",
|
||||
role=role,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get users by role: {str(e)}")
|
||||
|
||||
async def search_users(self, search_term: str, skip: int = 0, limit: int = 50) -> List[User]:
|
||||
"""Search users by email or full name"""
|
||||
try:
|
||||
return await self.search(
|
||||
search_term=search_term,
|
||||
search_fields=["email", "full_name"],
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to search users",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to search users: {str(e)}")
|
||||
|
||||
async def get_user_statistics(self) -> Dict[str, Any]:
|
||||
"""Get user statistics"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_users = await self.count()
|
||||
active_users = await self.count(filters={"is_active": True})
|
||||
verified_users = await self.count(filters={"is_verified": True})
|
||||
|
||||
# Get users by role using raw query
|
||||
role_query = text("""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_active = true
|
||||
GROUP BY role
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(role_query)
|
||||
role_stats = {row.role: row.count for row in result.fetchall()}
|
||||
|
||||
# Recent activity (users created in last 30 days)
|
||||
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
recent_users_query = text("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM users
|
||||
WHERE created_at >= :thirty_days_ago
|
||||
""")
|
||||
|
||||
recent_result = await self.session.execute(
|
||||
recent_users_query,
|
||||
{"thirty_days_ago": thirty_days_ago}
|
||||
)
|
||||
recent_users = recent_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"inactive_users": total_users - active_users,
|
||||
"verified_users": verified_users,
|
||||
"unverified_users": active_users - verified_users,
|
||||
"recent_registrations": recent_users,
|
||||
"users_by_role": role_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user statistics", error=str(e))
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"inactive_users": 0,
|
||||
"verified_users": 0,
|
||||
"unverified_users": 0,
|
||||
"recent_registrations": 0,
|
||||
"users_by_role": {}
|
||||
}
|
||||
Reference in New Issue
Block a user