Initial commit - production deployment
This commit is contained in:
313
services/auth/app/repositories/onboarding_repository.py
Normal file
313
services/auth/app/repositories/onboarding_repository.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# services/auth/app/repositories/onboarding_repository.py
|
||||
"""
|
||||
Onboarding Repository for database operations
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete, and_
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class OnboardingRepository:
|
||||
"""Repository for onboarding progress operations"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
|
||||
"""Get all onboarding steps for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
.order_by(UserOnboardingProgress.created_at)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user progress steps for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
|
||||
"""Get a specific step for a user"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingProgress)
|
||||
.where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_step(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
completed: bool,
|
||||
step_data: Dict[str, Any] = None,
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Insert or update a user's onboarding step
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
completed: Whether the step is completed
|
||||
step_data: Additional data for the step
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
completed_at = datetime.now(timezone.utc) if completed else None
|
||||
step_data = step_data or {}
|
||||
|
||||
# Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
|
||||
stmt = insert(UserOnboardingProgress).values(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=completed,
|
||||
completed_at=completed_at,
|
||||
step_data=step_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id', 'step_name'],
|
||||
set_=dict(
|
||||
completed=stmt.excluded.completed,
|
||||
completed_at=stmt.excluded.completed_at,
|
||||
step_data=stmt.excluded.step_data,
|
||||
updated_at=stmt.excluded.updated_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingProgress)
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
# Only commit if auto_commit is True (not within a UnitOfWork)
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
# Flush to ensure the statement is executed
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
|
||||
"""Get user's onboarding summary"""
|
||||
try:
|
||||
result = await self.db.execute(
|
||||
select(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def upsert_user_summary(
|
||||
self,
|
||||
user_id: str,
|
||||
current_step: str,
|
||||
next_step: Optional[str],
|
||||
completion_percentage: float,
|
||||
fully_completed: bool,
|
||||
steps_completed_count: str
|
||||
) -> UserOnboardingSummary:
|
||||
"""Insert or update user's onboarding summary"""
|
||||
try:
|
||||
# Use PostgreSQL UPSERT
|
||||
stmt = insert(UserOnboardingSummary).values(
|
||||
user_id=user_id,
|
||||
current_step=current_step,
|
||||
next_step=next_step,
|
||||
completion_percentage=str(completion_percentage),
|
||||
fully_completed=fully_completed,
|
||||
steps_completed_count=steps_completed_count,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
last_activity_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# On conflict, update the existing record
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['user_id'],
|
||||
set_=dict(
|
||||
current_step=stmt.excluded.current_step,
|
||||
next_step=stmt.excluded.next_step,
|
||||
completion_percentage=stmt.excluded.completion_percentage,
|
||||
fully_completed=stmt.excluded.fully_completed,
|
||||
steps_completed_count=stmt.excluded.steps_completed_count,
|
||||
updated_at=stmt.excluded.updated_at,
|
||||
last_activity_at=stmt.excluded.last_activity_at
|
||||
)
|
||||
)
|
||||
|
||||
# Return the updated record
|
||||
stmt = stmt.returning(UserOnboardingSummary)
|
||||
result = await self.db.execute(stmt)
|
||||
await self.db.commit()
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error upserting summary for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def delete_user_progress(self, user_id: str) -> bool:
|
||||
"""Delete all onboarding progress for a user"""
|
||||
try:
|
||||
# Delete steps
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingProgress)
|
||||
.where(UserOnboardingProgress.user_id == user_id)
|
||||
)
|
||||
|
||||
# Delete summary
|
||||
await self.db.execute(
|
||||
delete(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.user_id == user_id)
|
||||
)
|
||||
|
||||
await self.db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting progress for user {user_id}: {e}")
|
||||
await self.db.rollback()
|
||||
return False
|
||||
|
||||
async def save_step_data(
|
||||
self,
|
||||
user_id: str,
|
||||
step_name: str,
|
||||
step_data: Dict[str, Any],
|
||||
auto_commit: bool = True
|
||||
) -> UserOnboardingProgress:
|
||||
"""Save data for a specific step without marking it as completed
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
step_name: Name of the step
|
||||
step_data: Data to save
|
||||
auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)
|
||||
"""
|
||||
try:
|
||||
# Get existing step or create new one
|
||||
existing_step = await self.get_user_step(user_id, step_name)
|
||||
|
||||
if existing_step:
|
||||
# Update existing step data (merge with existing data)
|
||||
merged_data = {**(existing_step.step_data or {}), **step_data}
|
||||
|
||||
stmt = update(UserOnboardingProgress).where(
|
||||
and_(
|
||||
UserOnboardingProgress.user_id == user_id,
|
||||
UserOnboardingProgress.step_name == step_name
|
||||
)
|
||||
).values(
|
||||
step_data=merged_data,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
).returning(UserOnboardingProgress)
|
||||
|
||||
result = await self.db.execute(stmt)
|
||||
|
||||
if auto_commit:
|
||||
await self.db.commit()
|
||||
else:
|
||||
await self.db.flush()
|
||||
|
||||
return result.scalars().first()
|
||||
else:
|
||||
# Create new step with data but not completed
|
||||
return await self.upsert_user_step(
|
||||
user_id=user_id,
|
||||
step_name=step_name,
|
||||
completed=False,
|
||||
step_data=step_data,
|
||||
auto_commit=auto_commit
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
|
||||
if auto_commit:
|
||||
await self.db.rollback()
|
||||
raise
|
||||
|
||||
async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get data for a specific step"""
|
||||
try:
|
||||
step = await self.get_user_step(user_id, step_name)
|
||||
return step.step_data if step else None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get subscription parameters saved during onboarding for tenant creation"""
|
||||
try:
|
||||
step_data = await self.get_step_data(user_id, "user_registered")
|
||||
if step_data:
|
||||
# Extract subscription-related parameters
|
||||
subscription_params = {
|
||||
"subscription_plan": step_data.get("subscription_plan", "starter"),
|
||||
"billing_cycle": step_data.get("billing_cycle", "monthly"),
|
||||
"coupon_code": step_data.get("coupon_code"),
|
||||
"payment_method_id": step_data.get("payment_method_id"),
|
||||
"payment_customer_id": step_data.get("payment_customer_id"),
|
||||
"saved_at": step_data.get("saved_at")
|
||||
}
|
||||
return subscription_params
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def get_completion_stats(self) -> Dict[str, Any]:
|
||||
"""Get completion statistics across all users"""
|
||||
try:
|
||||
# Get total users with onboarding data
|
||||
total_result = await self.db.execute(
|
||||
select(UserOnboardingSummary).count()
|
||||
)
|
||||
total_users = total_result.scalar()
|
||||
|
||||
# Get completed users
|
||||
completed_result = await self.db.execute(
|
||||
select(UserOnboardingSummary)
|
||||
.where(UserOnboardingSummary.fully_completed == True)
|
||||
.count()
|
||||
)
|
||||
completed_users = completed_result.scalar()
|
||||
|
||||
return {
|
||||
"total_users_in_onboarding": total_users,
|
||||
"fully_completed_users": completed_users,
|
||||
"completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting completion stats: {e}")
|
||||
return {
|
||||
"total_users_in_onboarding": 0,
|
||||
"fully_completed_users": 0,
|
||||
"completion_rate": 0
|
||||
}
|
||||
Reference in New Issue
Block a user