# services/auth/app/repositories/onboarding_repository.py
|
|
|
|
|
"""Onboarding repository for database operations."""
|
|
|
|
|
|
|
|
|
|
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import structlog
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary
|
|
|
|
|
|
|
|
|
|
# Module-level structlog logger used by the repository methods to report
# database errors before they return a fallback value or re-raise.
logger = structlog.get_logger()
|
|
|
|
|
|
|
|
|
|
class OnboardingRepository:
    """Repository for onboarding progress operations.

    Wraps an ``AsyncSession`` and reads/writes two PostgreSQL-backed models:
    ``UserOnboardingProgress`` (one row per ``(user_id, step_name)``) and
    ``UserOnboardingSummary`` (one row per ``user_id``).

    Write methods take ``auto_commit``. With ``True`` (the default) the
    method commits on success and rolls back on failure. With ``False``
    (e.g. inside a UnitOfWork) it only flushes and leaves commit/rollback
    to the caller.

    Read methods are best-effort: on any database error they log and return
    an empty/``None`` result instead of raising.

    NOTE(review): if the session factory uses ``expire_on_commit=True``,
    instances returned by auto-committing writes are expired, and touching
    their attributes afterwards triggers a refresh. Confirm the session
    configuration.
    """

    def __init__(self, db: AsyncSession):
        """Bind the repository to an open async session.

        Args:
            db: Session used for every statement issued by this repository.
        """
        self.db = db

    async def _finish_write(self, auto_commit: bool) -> None:
        """Commit if this call owns the transaction; otherwise only flush.

        Flushing pushes pending ORM state to the database while leaving the
        caller-owned transaction open.
        """
        if auto_commit:
            await self.db.commit()
        else:
            await self.db.flush()

    async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
        """Get all onboarding steps for a user, ordered by ``created_at``.

        Args:
            user_id: User whose steps are fetched.

        Returns:
            The step rows (possibly empty). Also returns an empty list, after
            logging, if the query fails.
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingProgress)
                .where(UserOnboardingProgress.user_id == user_id)
                .order_by(UserOnboardingProgress.created_at)
            )
            # scalars().all() yields a Sequence; materialize to honour List.
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Error getting user progress steps for {user_id}: {e}")
            return []

    async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
        """Get a specific step for a user.

        Args:
            user_id: Owner of the step.
            step_name: Name of the step.

        Returns:
            The matching row, or ``None`` if it does not exist or the query
            fails (the failure is logged).
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingProgress).where(
                    and_(
                        UserOnboardingProgress.user_id == user_id,
                        UserOnboardingProgress.step_name == step_name,
                    )
                )
            )
            return result.scalars().first()
        except Exception as e:
            logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
            return None

    async def upsert_user_step(
        self,
        user_id: str,
        step_name: str,
        completed: bool,
        step_data: Optional[Dict[str, Any]] = None,
        auto_commit: bool = True,
    ) -> UserOnboardingProgress:
        """Insert or update a user's onboarding step.

        On conflict on ``(user_id, step_name)`` the existing row is overwritten:
        ``completed``, ``completed_at``, ``step_data`` and ``updated_at`` all
        take the new values (``step_data`` is replaced, not merged).

        Args:
            user_id: User ID.
            step_name: Name of the step.
            completed: Whether the step is completed. When True,
                ``completed_at`` is set to now; otherwise it is cleared.
            step_data: Additional data for the step (``None`` -> ``{}``).
            auto_commit: Whether to auto-commit (set to False when used
                within UnitOfWork).

        Returns:
            The inserted/updated row as produced by ``RETURNING``.

        Raises:
            Exception: any database error is logged and re-raised; the
                session is rolled back first only when ``auto_commit``.
        """
        try:
            # One timestamp so completed_at and updated_at agree exactly.
            now = datetime.now(timezone.utc)
            completed_at = now if completed else None
            step_data = step_data or {}

            # Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
            stmt = insert(UserOnboardingProgress).values(
                user_id=user_id,
                step_name=step_name,
                completed=completed,
                completed_at=completed_at,
                step_data=step_data,
                updated_at=now,
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=['user_id', 'step_name'],
                set_=dict(
                    completed=stmt.excluded.completed,
                    completed_at=stmt.excluded.completed_at,
                    step_data=stmt.excluded.step_data,
                    updated_at=stmt.excluded.updated_at,
                ),
            ).returning(UserOnboardingProgress)

            # populate_existing: if this row's instance is already in the
            # session's identity map, overwrite its stale attributes with the
            # values RETURNING produced instead of handing back the old state.
            result = await self.db.execute(
                stmt, execution_options={"populate_existing": True}
            )
            # Extract the row before committing/flushing.
            step = result.scalars().first()
            await self._finish_write(auto_commit)
            return step
        except Exception as e:
            logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
            if auto_commit:
                await self.db.rollback()
            raise

    async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
        """Get user's onboarding summary.

        Returns:
            The summary row, or ``None`` if absent or if the query fails
            (the failure is logged).
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingSummary).where(UserOnboardingSummary.user_id == user_id)
            )
            return result.scalars().first()
        except Exception as e:
            logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
            return None

    async def upsert_user_summary(
        self,
        user_id: str,
        current_step: str,
        next_step: Optional[str],
        completion_percentage: float,
        fully_completed: bool,
        steps_completed_count: str,
        auto_commit: bool = True,
    ) -> UserOnboardingSummary:
        """Insert or update user's onboarding summary.

        On conflict on ``user_id`` every summary column is overwritten and
        both ``updated_at`` and ``last_activity_at`` are set to now.

        Args:
            user_id: User ID.
            current_step: Step the user is currently on.
            next_step: Following step, if any.
            completion_percentage: Stored via ``str()``.
            fully_completed: Whether onboarding is finished.
            steps_completed_count: Stored as given.
            auto_commit: Whether to auto-commit (set to False when used
                within UnitOfWork). Defaults to True, matching the previous
                always-commit behaviour.

        Returns:
            The inserted/updated row as produced by ``RETURNING``.

        Raises:
            Exception: any database error is logged and re-raised; the
                session is rolled back first only when ``auto_commit``.
        """
        try:
            now = datetime.now(timezone.utc)

            # Use PostgreSQL UPSERT
            stmt = insert(UserOnboardingSummary).values(
                user_id=user_id,
                current_step=current_step,
                next_step=next_step,
                # Written as text; presumably the column is a string type —
                # TODO confirm against the model.
                completion_percentage=str(completion_percentage),
                fully_completed=fully_completed,
                steps_completed_count=steps_completed_count,
                updated_at=now,
                last_activity_at=now,
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=['user_id'],
                set_=dict(
                    current_step=stmt.excluded.current_step,
                    next_step=stmt.excluded.next_step,
                    completion_percentage=stmt.excluded.completion_percentage,
                    fully_completed=stmt.excluded.fully_completed,
                    steps_completed_count=stmt.excluded.steps_completed_count,
                    updated_at=stmt.excluded.updated_at,
                    last_activity_at=stmt.excluded.last_activity_at,
                ),
            ).returning(UserOnboardingSummary)

            # See upsert_user_step: refresh any identity-mapped instance.
            result = await self.db.execute(
                stmt, execution_options={"populate_existing": True}
            )
            summary = result.scalars().first()
            await self._finish_write(auto_commit)
            return summary
        except Exception as e:
            logger.error(f"Error upserting summary for user {user_id}: {e}")
            if auto_commit:
                await self.db.rollback()
            raise

    async def delete_user_progress(self, user_id: str, auto_commit: bool = True) -> bool:
        """Delete all onboarding progress (steps and summary) for a user.

        Args:
            user_id: User whose onboarding rows are removed.
            auto_commit: Whether to auto-commit (set to False when used
                within UnitOfWork). Defaults to True, matching the previous
                always-commit behaviour.

        Returns:
            True on success. False if a database error occurred (logged; the
            session is rolled back only when ``auto_commit``).
        """
        try:
            # Delete steps
            await self.db.execute(
                delete(UserOnboardingProgress).where(UserOnboardingProgress.user_id == user_id)
            )
            # Delete summary
            await self.db.execute(
                delete(UserOnboardingSummary).where(UserOnboardingSummary.user_id == user_id)
            )
            await self._finish_write(auto_commit)
            return True
        except Exception as e:
            logger.error(f"Error deleting progress for user {user_id}: {e}")
            if auto_commit:
                await self.db.rollback()
            return False

    async def save_step_data(
        self,
        user_id: str,
        step_name: str,
        step_data: Dict[str, Any],
        auto_commit: bool = True,
    ) -> UserOnboardingProgress:
        """Save data for a specific step without marking it as completed.

        If the step exists, ``step_data`` is shallow-merged over the stored
        data (new keys win) and ``completed`` is left untouched. Otherwise a
        new, not-completed step is created with ``step_data``.

        NOTE(review): the read-then-write is not atomic. Two concurrent calls
        for a not-yet-existing step can both take the create path, and the
        later one then overwrites the earlier one's data.

        Args:
            user_id: User ID.
            step_name: Name of the step.
            step_data: Data to save (``None`` is treated as ``{}``).
            auto_commit: Whether to auto-commit (set to False when used
                within UnitOfWork).

        Returns:
            The updated or created row.

        Raises:
            Exception: any database error is logged and re-raised; the
                session is rolled back first only when ``auto_commit``.
        """
        try:
            step_data = step_data or {}
            existing_step = await self.get_user_step(user_id, step_name)

            if existing_step is None:
                # Create new step with data but not completed
                return await self.upsert_user_step(
                    user_id=user_id,
                    step_name=step_name,
                    completed=False,
                    step_data=step_data,
                    auto_commit=auto_commit,
                )

            # Update existing step data (merge with existing data)
            merged_data = {**(existing_step.step_data or {}), **step_data}
            stmt = (
                update(UserOnboardingProgress)
                .where(
                    and_(
                        UserOnboardingProgress.user_id == user_id,
                        UserOnboardingProgress.step_name == step_name,
                    )
                )
                .values(step_data=merged_data, updated_at=datetime.now(timezone.utc))
                .returning(UserOnboardingProgress)
            )
            result = await self.db.execute(stmt)
            step = result.scalars().first()
            await self._finish_write(auto_commit)
            return step
        except Exception as e:
            logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
            if auto_commit:
                await self.db.rollback()
            raise

    async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
        """Get data for a specific step.

        Returns:
            The step's ``step_data``, or ``None`` if the step does not exist
            or an error occurs (logged).
        """
        try:
            step = await self.get_user_step(user_id, step_name)
            return step.step_data if step else None
        except Exception as e:
            logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
            return None

    async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get subscription parameters saved during onboarding for tenant creation.

        Reads the data of the ``"user_registered"`` step and picks out the
        subscription-related keys. ``subscription_plan`` defaults to
        ``"starter"`` and ``billing_cycle`` to ``"monthly"``; the other keys
        default to ``None``.

        Returns:
            The extracted parameters. Returns ``None`` if the step has no
            (or empty) data, or if an error occurs (logged).
        """
        try:
            step_data = await self.get_step_data(user_id, "user_registered")
            if not step_data:
                return None
            # Extract subscription-related parameters
            return {
                "subscription_plan": step_data.get("subscription_plan", "starter"),
                "billing_cycle": step_data.get("billing_cycle", "monthly"),
                "coupon_code": step_data.get("coupon_code"),
                "payment_method_id": step_data.get("payment_method_id"),
                "payment_customer_id": step_data.get("payment_customer_id"),
                "saved_at": step_data.get("saved_at"),
            }
        except Exception as e:
            logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
            return None

    async def get_completion_stats(self) -> Dict[str, Any]:
        """Get completion statistics across all users.

        Counts summary rows (users with onboarding data) and those with
        ``fully_completed`` true in a single query, so both numbers come from
        the same snapshot.

        Returns:
            ``total_users_in_onboarding``, ``fully_completed_users`` and
            ``completion_rate`` (percentage, 0 when there are no users). All
            values are 0 if the query fails (logged).
        """
        try:
            # The previous select(...).count() does not exist on SQLAlchemy
            # 1.4+/2.0 Select objects, so it always failed into the fallback.
            row = (
                await self.db.execute(
                    select(
                        func.count().label("total"),
                        func.count()
                        .filter(UserOnboardingSummary.fully_completed.is_(True))
                        .label("completed"),
                    ).select_from(UserOnboardingSummary)
                )
            ).one()
            total_users = row.total or 0
            completed_users = row.completed or 0

            return {
                "total_users_in_onboarding": total_users,
                "fully_completed_users": completed_users,
                "completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0,
            }
        except Exception as e:
            logger.error(f"Error getting completion stats: {e}")
            return {
                "total_users_in_onboarding": 0,
                "fully_completed_users": 0,
                "completion_rate": 0,
            }