# services/auth/app/repositories/onboarding_repository.py
"""
Onboarding Repository for database operations
"""

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, and_, func
from sqlalchemy.dialects.postgresql import insert
from typing import List, Dict, Any, Optional
from datetime import datetime, timezone
import structlog

from app.models.onboarding import UserOnboardingProgress, UserOnboardingSummary

logger = structlog.get_logger()


class OnboardingRepository:
    """Repository for onboarding progress operations.

    Wraps an ``AsyncSession`` and exposes helpers for two models:
    ``UserOnboardingProgress`` (upserted on ``(user_id, step_name)``) and
    ``UserOnboardingSummary`` (upserted on ``user_id``).

    Read helpers log errors and return an empty value (``[]`` / ``None`` /
    zeroed stats). Write helpers log errors, roll back when they own the
    transaction (``auto_commit=True``), and re-raise.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _commit_or_flush(self, auto_commit: bool) -> None:
        """Commit if this call owns the transaction, otherwise only flush."""
        if auto_commit:
            await self.db.commit()
        else:
            # Inside a UnitOfWork the caller commits; flush so the statement
            # is sent to the database now.
            await self.db.flush()

    async def _rollback_if_owned(self, auto_commit: bool) -> None:
        """Roll back only if this call owns the transaction."""
        if auto_commit:
            await self.db.rollback()

    async def get_user_progress_steps(self, user_id: str) -> List[UserOnboardingProgress]:
        """Get all onboarding steps for a user, ordered by ``created_at``.

        Args:
            user_id: User ID

        Returns:
            The user's step rows, or an empty list if the query fails
            (the error is logged).
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingProgress)
                .where(UserOnboardingProgress.user_id == user_id)
                .order_by(UserOnboardingProgress.created_at)
            )
            return list(result.scalars().all())
        except Exception as e:
            logger.error(f"Error getting user progress steps for {user_id}: {e}")
            return []

    async def get_user_step(self, user_id: str, step_name: str) -> Optional[UserOnboardingProgress]:
        """Get a specific step for a user.

        Returns:
            The matching row, or ``None`` if it does not exist or the query
            fails (the error is logged).
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingProgress)
                .where(
                    and_(
                        UserOnboardingProgress.user_id == user_id,
                        UserOnboardingProgress.step_name == step_name
                    )
                )
            )
            return result.scalars().first()
        except Exception as e:
            logger.error(f"Error getting step {step_name} for user {user_id}: {e}")
            return None

    async def upsert_user_step(
        self,
        user_id: str,
        step_name: str,
        completed: bool,
        step_data: Optional[Dict[str, Any]] = None,
        auto_commit: bool = True
    ) -> UserOnboardingProgress:
        """Insert or update a user's onboarding step.

        On conflict on ``(user_id, step_name)`` the existing row's
        ``completed``, ``completed_at``, ``step_data`` and ``updated_at`` are
        overwritten (``step_data`` is replaced, not merged).

        Args:
            user_id: User ID
            step_name: Name of the step
            completed: Whether the step is completed; ``completed_at`` is set
                to now when True and cleared (``None``) when False
            step_data: Additional data for the step (defaults to ``{}``)
            auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)

        Returns:
            The inserted/updated row as returned by ``RETURNING``.

        Raises:
            Any database error, after logging (and rollback if ``auto_commit``).
        """
        try:
            now = datetime.now(timezone.utc)
            completed_at = now if completed else None
            step_data = step_data or {}

            # Use PostgreSQL UPSERT (INSERT ... ON CONFLICT ... DO UPDATE)
            stmt = insert(UserOnboardingProgress).values(
                user_id=user_id,
                step_name=step_name,
                completed=completed,
                completed_at=completed_at,
                step_data=step_data,
                updated_at=now
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=['user_id', 'step_name'],
                set_=dict(
                    completed=stmt.excluded.completed,
                    completed_at=stmt.excluded.completed_at,
                    step_data=stmt.excluded.step_data,
                    updated_at=stmt.excluded.updated_at
                )
            ).returning(UserOnboardingProgress)

            result = await self.db.execute(stmt)
            # Read the RETURNING row before the transaction is finalized.
            # NOTE(review): if the session uses expire_on_commit=True the
            # returned instance is still expired by the commit — confirm the
            # session configuration.
            row = result.scalars().first()
            await self._commit_or_flush(auto_commit)
            return row
        except Exception as e:
            logger.error(f"Error upserting step {step_name} for user {user_id}: {e}")
            await self._rollback_if_owned(auto_commit)
            raise

    async def get_user_summary(self, user_id: str) -> Optional[UserOnboardingSummary]:
        """Get user's onboarding summary.

        Returns:
            The summary row, or ``None`` if it does not exist or the query
            fails (the error is logged).
        """
        try:
            result = await self.db.execute(
                select(UserOnboardingSummary)
                .where(UserOnboardingSummary.user_id == user_id)
            )
            return result.scalars().first()
        except Exception as e:
            logger.error(f"Error getting onboarding summary for user {user_id}: {e}")
            return None

    async def upsert_user_summary(
        self,
        user_id: str,
        current_step: str,
        next_step: Optional[str],
        completion_percentage: float,
        fully_completed: bool,
        steps_completed_count: str,
        auto_commit: bool = True
    ) -> UserOnboardingSummary:
        """Insert or update user's onboarding summary (keyed on ``user_id``).

        Args:
            user_id: User ID
            current_step: Step the user is currently on
            next_step: Following step, if any
            completion_percentage: Stored as ``str(completion_percentage)``
                (presumably a string column — confirm against the model)
            fully_completed: Whether onboarding is finished
            steps_completed_count: Stored as given
            auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)

        Returns:
            The inserted/updated row as returned by ``RETURNING``.

        Raises:
            Any database error, after logging (and rollback if ``auto_commit``).
        """
        try:
            # One timestamp so updated_at and last_activity_at agree exactly.
            now = datetime.now(timezone.utc)

            # Use PostgreSQL UPSERT
            stmt = insert(UserOnboardingSummary).values(
                user_id=user_id,
                current_step=current_step,
                next_step=next_step,
                completion_percentage=str(completion_percentage),
                fully_completed=fully_completed,
                steps_completed_count=steps_completed_count,
                updated_at=now,
                last_activity_at=now
            )
            stmt = stmt.on_conflict_do_update(
                index_elements=['user_id'],
                set_=dict(
                    current_step=stmt.excluded.current_step,
                    next_step=stmt.excluded.next_step,
                    completion_percentage=stmt.excluded.completion_percentage,
                    fully_completed=stmt.excluded.fully_completed,
                    steps_completed_count=stmt.excluded.steps_completed_count,
                    updated_at=stmt.excluded.updated_at,
                    last_activity_at=stmt.excluded.last_activity_at
                )
            ).returning(UserOnboardingSummary)

            result = await self.db.execute(stmt)
            row = result.scalars().first()
            await self._commit_or_flush(auto_commit)
            return row
        except Exception as e:
            logger.error(f"Error upserting summary for user {user_id}: {e}")
            await self._rollback_if_owned(auto_commit)
            raise

    async def delete_user_progress(self, user_id: str, auto_commit: bool = True) -> bool:
        """Delete all onboarding step rows and the summary row for a user.

        Args:
            user_id: User ID
            auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)

        Returns:
            True on success, False if a database error occurred (logged).
        """
        try:
            await self.db.execute(
                delete(UserOnboardingProgress)
                .where(UserOnboardingProgress.user_id == user_id)
            )
            await self.db.execute(
                delete(UserOnboardingSummary)
                .where(UserOnboardingSummary.user_id == user_id)
            )
            await self._commit_or_flush(auto_commit)
            return True
        except Exception as e:
            logger.error(f"Error deleting progress for user {user_id}: {e}")
            await self._rollback_if_owned(auto_commit)
            return False

    async def save_step_data(
        self,
        user_id: str,
        step_name: str,
        step_data: Dict[str, Any],
        auto_commit: bool = True
    ) -> UserOnboardingProgress:
        """Save data for a specific step without marking it as completed.

        If the step exists, ``step_data`` is shallow-merged over the stored
        data (new keys win) and ``completed`` is left untouched. Otherwise a
        new, not-completed step is created via ``upsert_user_step``.

        Args:
            user_id: User ID
            step_name: Name of the step
            step_data: Data to save
            auto_commit: Whether to auto-commit (set to False when used within UnitOfWork)

        Raises:
            Any database error, after logging (and rollback if ``auto_commit``).
        """
        try:
            step_data = step_data or {}
            existing_step = await self.get_user_step(user_id, step_name)

            if not existing_step:
                # Create new step with data but not completed
                return await self.upsert_user_step(
                    user_id=user_id,
                    step_name=step_name,
                    completed=False,
                    step_data=step_data,
                    auto_commit=auto_commit
                )

            merged_data = {**(existing_step.step_data or {}), **step_data}
            stmt = update(UserOnboardingProgress).where(
                and_(
                    UserOnboardingProgress.user_id == user_id,
                    UserOnboardingProgress.step_name == step_name
                )
            ).values(
                step_data=merged_data,
                updated_at=datetime.now(timezone.utc)
            ).returning(UserOnboardingProgress)

            result = await self.db.execute(stmt)
            row = result.scalars().first()
            await self._commit_or_flush(auto_commit)
            return row
        except Exception as e:
            logger.error(f"Error saving step data for {step_name}, user {user_id}: {e}")
            await self._rollback_if_owned(auto_commit)
            raise

    async def get_step_data(self, user_id: str, step_name: str) -> Optional[Dict[str, Any]]:
        """Get the stored ``step_data`` for a step, or ``None`` if the step is missing."""
        try:
            step = await self.get_user_step(user_id, step_name)
            return step.step_data if step else None
        except Exception as e:
            logger.error(f"Error getting step data for {step_name}, user {user_id}: {e}")
            return None

    async def get_subscription_parameters(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get subscription parameters saved during onboarding for tenant creation.

        Reads the ``user_registered`` step's data. ``subscription_plan`` defaults
        to ``"starter"`` and ``billing_cycle`` to ``"monthly"``; other keys
        default to ``None``.

        Returns:
            The parameter dict, or ``None`` if no data is stored or an error
            occurs (logged).
        """
        try:
            step_data = await self.get_step_data(user_id, "user_registered")
            if not step_data:
                return None
            return {
                "subscription_plan": step_data.get("subscription_plan", "starter"),
                "billing_cycle": step_data.get("billing_cycle", "monthly"),
                "coupon_code": step_data.get("coupon_code"),
                "payment_method_id": step_data.get("payment_method_id"),
                "payment_customer_id": step_data.get("payment_customer_id"),
                "saved_at": step_data.get("saved_at")
            }
        except Exception as e:
            logger.error(f"Error getting subscription parameters for user {user_id}: {e}")
            return None

    async def get_completion_stats(self) -> Dict[str, Any]:
        """Get completion statistics across all users with a summary row.

        Returns:
            ``total_users_in_onboarding``, ``fully_completed_users`` and
            ``completion_rate`` (a percentage, 0 when there are no users).
            All values are 0 if a database error occurs (logged).
        """
        try:
            # `Select` has no `.count()` (that was the legacy `Query` API);
            # build an explicit COUNT(*) query instead.
            total_result = await self.db.execute(
                select(func.count()).select_from(UserOnboardingSummary)
            )
            total_users = total_result.scalar_one()

            completed_result = await self.db.execute(
                select(func.count())
                .select_from(UserOnboardingSummary)
                .where(UserOnboardingSummary.fully_completed.is_(True))
            )
            completed_users = completed_result.scalar_one()

            return {
                "total_users_in_onboarding": total_users,
                "fully_completed_users": completed_users,
                "completion_rate": (completed_users / total_users * 100) if total_users > 0 else 0
            }
        except Exception as e:
            logger.error(f"Error getting completion stats: {e}")
            return {
                "total_users_in_onboarding": 0,
                "fully_completed_users": 0,
                "completion_rate": 0
            }