358 lines
12 KiB
Python
358 lines
12 KiB
Python
"""
Registration State Management Service

Tracks registration progress and handles state transitions.
"""
|
|
|
|
import structlog
|
|
from typing import Dict, Any, Optional
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from uuid import uuid4
|
|
from shared.exceptions.registration_exceptions import (
|
|
RegistrationStateError,
|
|
InvalidStateTransitionError
|
|
)
|
|
|
|
# Configure logging: one module-level structlog logger used by the classes
# below; calls attach structured context as keyword arguments (state_id=...).
logger = structlog.get_logger()
|
|
|
|
|
|
class RegistrationState(Enum):
    """
    Stages of the registration lifecycle.

    Each member's value is the lowercase string that state records store in
    their ``current_state`` field. Members are declared in the order a
    successful registration advances through them.
    """

    # Happy-path stages, in forward-transition order
    INITIATED = "initiated"
    PAYMENT_VERIFICATION_PENDING = "payment_verification_pending"
    PAYMENT_VERIFIED = "payment_verified"
    SUBSCRIPTION_CREATED = "subscription_created"
    USER_CREATED = "user_created"
    COMPLETED = "completed"

    # Terminal error stage
    FAILED = "failed"
|
|
|
|
|
|
class RegistrationStateService:
    """
    Registration State Management Service

    Tracks and manages registration process state. Each registration is a
    plain dict held in ``self.registration_states`` keyed by its
    ``state_id``; the store lives in process memory only.
    """

    def __init__(self):
        """Initialize the state service with an empty in-memory store."""
        # In production, this would use a database
        self.registration_states: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _now() -> str:
        """Return the current local time as an ISO-8601 string."""
        # NOTE(review): datetime.now() is naive local time; confirm whether
        # consumers expect UTC / timezone-aware timestamps.
        return datetime.now().isoformat()

    def _get_record(self, state_id: str) -> Dict[str, Any]:
        """
        Look up the stored record for a registration.

        Args:
            state_id: Registration state ID

        Returns:
            The stored record itself (not a copy)

        Raises:
            RegistrationStateError: If no record exists for ``state_id``
        """
        if state_id not in self.registration_states:
            raise RegistrationStateError(f"Registration state {state_id} not found")
        return self.registration_states[state_id]

    async def create_registration_state(
        self,
        email: str,
        user_data: Dict[str, Any]
    ) -> str:
        """
        Create new registration state

        The new record starts in ``RegistrationState.INITIATED`` with the
        Stripe-style id fields (setup_intent_id, customer_id,
        subscription_id) and ``error`` set to None.

        Args:
            email: User email
            user_data: Registration data (stored by reference, not copied)

        Returns:
            Registration state ID (a random UUID4 string)

        Raises:
            RegistrationStateError: If the record cannot be created
        """
        try:
            state_id = str(uuid4())
            # One timestamp so created_at and updated_at start out identical.
            now = self._now()

            registration_state = {
                'state_id': state_id,
                'email': email,
                'current_state': RegistrationState.INITIATED.value,
                'created_at': now,
                'updated_at': now,
                'user_data': user_data,
                'setup_intent_id': None,
                'customer_id': None,
                'subscription_id': None,
                'error': None
            }

            self.registration_states[state_id] = registration_state

            logger.info("Registration state created",
                        state_id=state_id,
                        email=email,
                        current_state=RegistrationState.INITIATED.value)

            return state_id

        except Exception as e:
            logger.error("Failed to create registration state",
                         error=str(e),
                         email=email,
                         exc_info=True)
            raise RegistrationStateError(f"State creation failed: {str(e)}") from e

    async def transition_state(
        self,
        state_id: str,
        new_state: RegistrationState,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Transition registration to new state with validation

        Only the moves allowed by ``_is_valid_transition`` are accepted.
        After the state changes, a non-empty ``context`` is merged into the
        record with ``dict.update``.

        Args:
            state_id: Registration state ID
            new_state: New state to transition to
            context: Additional context data merged into the record

        Raises:
            InvalidStateTransitionError: If transition is invalid
            RegistrationStateError: If the record is missing or the
                transition fails for any other reason
        """
        # Bound before the try so the error handler can always log it, even
        # when the record lookup itself is what failed.
        current_state: Optional[str] = None
        try:
            record = self._get_record(state_id)
            current_state = record['current_state']

            # Validate state transition
            if not self._is_valid_transition(current_state, new_state.value):
                raise InvalidStateTransitionError(
                    f"Invalid transition from {current_state} to {new_state.value}"
                )

            # Update state
            record['current_state'] = new_state.value
            record['updated_at'] = self._now()

            # Update context data.
            # NOTE(review): keys are merged verbatim, so a context containing
            # 'current_state', 'state_id' or 'updated_at' overwrites the values
            # just set (and bypasses the transition check above).
            if context:
                record.update(context)

            logger.info("Registration state transitioned",
                        state_id=state_id,
                        from_state=current_state,
                        to_state=new_state.value)

        except InvalidStateTransitionError:
            raise
        except Exception as e:
            logger.error("State transition failed",
                         error=str(e),
                         state_id=state_id,
                         from_state=current_state,
                         to_state=new_state.value,
                         exc_info=True)
            raise RegistrationStateError(f"State transition failed: {str(e)}") from e

    async def update_state_context(
        self,
        state_id: str,
        context: Dict[str, Any]
    ) -> None:
        """
        Update state context data

        Merges ``context`` into the record, then refreshes ``updated_at``.
        Because ``updated_at`` is written last, a value for it in ``context``
        is replaced. No state-machine validation is applied.

        Args:
            state_id: Registration state ID
            context: Context data to update

        Raises:
            RegistrationStateError: If the record is missing or the update fails
        """
        try:
            record = self._get_record(state_id)
            record.update(context)
            record['updated_at'] = self._now()

            logger.debug("Registration state context updated",
                         state_id=state_id,
                         context_keys=list(context.keys()))

        except Exception as e:
            logger.error("State context update failed",
                         error=str(e),
                         state_id=state_id,
                         exc_info=True)
            raise RegistrationStateError(f"State context update failed: {str(e)}") from e

    async def get_registration_state(
        self,
        state_id: str
    ) -> Dict[str, Any]:
        """
        Get registration state by ID

        Args:
            state_id: Registration state ID

        Returns:
            The stored registration record itself (not a copy), so mutating
            it changes the stored state

        Raises:
            RegistrationStateError: If state not found
        """
        try:
            return self._get_record(state_id)

        except Exception as e:
            logger.error("Failed to get registration state",
                         error=str(e),
                         state_id=state_id,
                         exc_info=True)
            raise RegistrationStateError(f"State retrieval failed: {str(e)}") from e

    async def rollback_state(
        self,
        state_id: str,
        target_state: RegistrationState
    ) -> None:
        """
        Rollback registration to previous state

        Only strictly earlier happy-path states are accepted (see
        ``_can_rollback``). On success the record's ``error`` field is set to
        "Registration rolled back".

        Args:
            state_id: Registration state ID
            target_state: State to rollback to

        Raises:
            InvalidStateTransitionError: If ``target_state`` is not earlier
                than the current state
            RegistrationStateError: If the record is missing or the rollback
                fails for any other reason
        """
        # Bound before the try so the error handler can always log it.
        current_state: Optional[str] = None
        try:
            record = self._get_record(state_id)
            current_state = record['current_state']

            # Only allow rollback to earlier states
            if not self._can_rollback(current_state, target_state.value):
                raise InvalidStateTransitionError(
                    f"Cannot rollback from {current_state} to {target_state.value}"
                )

            # Update state
            record['current_state'] = target_state.value
            record['updated_at'] = self._now()
            record['error'] = "Registration rolled back"

            logger.warning("Registration state rolled back",
                           state_id=state_id,
                           from_state=current_state,
                           to_state=target_state.value)

        except InvalidStateTransitionError:
            raise
        except Exception as e:
            logger.error("State rollback failed",
                         error=str(e),
                         state_id=state_id,
                         from_state=current_state,
                         to_state=target_state.value,
                         exc_info=True)
            raise RegistrationStateError(f"State rollback failed: {str(e)}") from e

    async def mark_registration_failed(
        self,
        state_id: str,
        error: str
    ) -> None:
        """
        Mark registration as failed

        Unlike ``transition_state``, no transition check is applied: a record
        in any state (including COMPLETED) is set to FAILED.

        Args:
            state_id: Registration state ID
            error: Error message stored in the record's ``error`` field

        Raises:
            RegistrationStateError: If the record is missing or the operation fails
        """
        try:
            record = self._get_record(state_id)
            record['current_state'] = RegistrationState.FAILED.value
            record['error'] = error
            record['updated_at'] = self._now()

            logger.error("Registration marked as failed",
                         state_id=state_id,
                         error=error)

        except Exception as e:
            logger.error("Failed to mark registration as failed",
                         error=str(e),
                         state_id=state_id,
                         exc_info=True)
            raise RegistrationStateError(f"Mark failed operation failed: {str(e)}") from e

    def _is_valid_transition(self, current_state: str, new_state: str) -> bool:
        """
        Validate state transition

        Each non-terminal state may advance to the next happy-path state or
        to FAILED; COMPLETED and FAILED allow no further transitions. Unknown
        ``current_state`` values allow nothing.

        Args:
            current_state: Current state value
            new_state: New state value

        Returns:
            True if transition is valid
        """
        # Define valid state transitions
        valid_transitions = {
            RegistrationState.INITIATED.value: [
                RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.PAYMENT_VERIFICATION_PENDING.value: [
                RegistrationState.PAYMENT_VERIFIED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.PAYMENT_VERIFIED.value: [
                RegistrationState.SUBSCRIPTION_CREATED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.SUBSCRIPTION_CREATED.value: [
                RegistrationState.USER_CREATED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.USER_CREATED.value: [
                RegistrationState.COMPLETED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.COMPLETED.value: [],
            RegistrationState.FAILED.value: []
        }

        return new_state in valid_transitions.get(current_state, [])

    def _can_rollback(self, current_state: str, target_state: str) -> bool:
        """
        Check if rollback to target state is allowed

        Rollback is permitted only to a strictly earlier state in the
        happy-path order. FAILED (or any unknown value) is not in that order,
        so rollbacks from or to it are rejected.

        Args:
            current_state: Current state value
            target_state: Target state value for rollback

        Returns:
            True if rollback is allowed
        """
        # Define state order for rollback validation
        state_order = [
            RegistrationState.INITIATED.value,
            RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
            RegistrationState.PAYMENT_VERIFIED.value,
            RegistrationState.SUBSCRIPTION_CREATED.value,
            RegistrationState.USER_CREATED.value,
            RegistrationState.COMPLETED.value
        ]

        try:
            current_index = state_order.index(current_state)
            target_index = state_order.index(target_state)
        except ValueError:
            return False

        # Can only rollback to earlier states
        return target_index < current_index
|
|
|
|
|
|
# Singleton instance for dependency injection.
# Created at import time, so all importers in a process share one in-memory store.
registration_state_service = RegistrationStateService()