Add subcription feature 3
This commit is contained in:
@@ -211,7 +211,8 @@ async def clone_demo_data(
|
||||
subscription_data.get('cancellation_effective_date'),
|
||||
session_time,
|
||||
"cancellation_effective_date"
|
||||
)
|
||||
),
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
@@ -245,7 +246,8 @@ async def clone_demo_data(
|
||||
"api_access": True,
|
||||
"priority_support": True
|
||||
},
|
||||
next_billing_date=datetime.now(timezone.utc) + timedelta(days=90)
|
||||
next_billing_date=datetime.now(timezone.utc) + timedelta(days=90),
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
@@ -323,7 +325,8 @@ async def clone_demo_data(
|
||||
max_users=-1, # Unlimited for demo
|
||||
max_locations=max_locations,
|
||||
max_products=-1, # Unlimited for demo
|
||||
features={}
|
||||
features={},
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
db.add(demo_subscription)
|
||||
|
||||
@@ -699,7 +702,8 @@ async def create_child_outlet(
|
||||
max_users=10, # Demo limits
|
||||
max_locations=1, # Single location for outlet
|
||||
max_products=200,
|
||||
features={}
|
||||
features={},
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
db.add(child_subscription)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1122,6 +1122,88 @@ async def upgrade_subscription_plan(
|
||||
detail="Failed to upgrade subscription plan"
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# REGISTRATION ORCHESTRATION (SECURE 3DS FLOW)
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/registration-payment-setup")
|
||||
async def registration_payment_setup(
|
||||
user_data: Dict[str, Any],
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
):
|
||||
"""
|
||||
Orchestrate initial payment setup for a new registration.
|
||||
This creates the customer and SetupIntent for 3DS verification.
|
||||
"""
|
||||
try:
|
||||
logger.info("Orchestrating registration payment setup",
|
||||
email=user_data.get('email'))
|
||||
|
||||
result = await payment_service.complete_registration_payment_flow(user_data)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("Registration payment setup orchestration failed",
|
||||
email=user_data.get('email'),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration payment setup failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/verify-and-complete-registration")
|
||||
async def verify_and_complete_registration(
|
||||
registration_data: Dict[str, Any],
|
||||
payment_service: PaymentService = Depends(get_payment_service),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Final step: Verify SetupIntent and create subscription + tenant.
|
||||
Called after frontend confirms 3DS with Stripe.
|
||||
"""
|
||||
try:
|
||||
setup_intent_id = registration_data.get('setup_intent_id')
|
||||
user_data = registration_data.get('user_data', {})
|
||||
|
||||
logger.info("Verifying and completing registration",
|
||||
email=user_data.get('email'),
|
||||
setup_intent_id=setup_intent_id)
|
||||
|
||||
# 1. Complete subscription creation after verification
|
||||
payment_result = await payment_service.complete_subscription_after_verification(
|
||||
setup_intent_id, user_data
|
||||
)
|
||||
|
||||
# 2. Create the bakery/tenant record
|
||||
# Note: In a real flow, we'd use the payment result (customer_id, subscription_id)
|
||||
# to properly link the new tenant.
|
||||
bakery_registration = BakeryRegistration(
|
||||
name=user_data.get('name', f"{user_data.get('full_name')}'s Bakery"),
|
||||
subdomain=user_data.get('subdomain', user_data.get('email').split('@')[0]),
|
||||
business_type=user_data.get('business_type', 'bakery'),
|
||||
link_existing_subscription=True,
|
||||
subscription_id=payment_result['subscription_id']
|
||||
)
|
||||
|
||||
# We need the user_id from the auth service call
|
||||
user_id = user_data.get('user_id')
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="Missing user_id in registration data")
|
||||
|
||||
tenant_result = await tenant_service.create_bakery(bakery_registration, user_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant": tenant_result,
|
||||
"payment": payment_result
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Verify and complete registration failed",
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Registration completion failed: {str(e)}"
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# PAYMENT OPERATIONS
|
||||
# ============================================================================
|
||||
@@ -1244,7 +1326,7 @@ async def register_with_subscription(
|
||||
**result
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to register with subscription", error=str(e))
|
||||
logger.error(f"Failed to register with subscription: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to register with subscription"
|
||||
|
||||
@@ -88,7 +88,10 @@ async def stripe_webhook(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing Stripe webhook", error=str(e), exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Webhook processing error"
|
||||
)
|
||||
# Return 200 OK even on processing errors to prevent Stripe retries
|
||||
# Only return 4xx for signature verification failures
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Webhook processing error",
|
||||
"details": str(e)
|
||||
}
|
||||
|
||||
@@ -67,6 +67,14 @@ class Tenant(Base):
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# 3D Secure (3DS) tracking
|
||||
threeds_authentication_required = Column(Boolean, default=False)
|
||||
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
|
||||
threeds_authentication_completed = Column(Boolean, default=False)
|
||||
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_threeds_setup_intent_id = Column(String(255), nullable=True)
|
||||
threeds_action_type = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships - only within tenant service
|
||||
members = relationship("TenantMember", back_populates="tenant", cascade="all, delete-orphan")
|
||||
subscriptions = relationship("Subscription", back_populates="tenant", cascade="all, delete-orphan")
|
||||
@@ -187,6 +195,14 @@ class Subscription(Base):
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# 3D Secure (3DS) tracking
|
||||
threeds_authentication_required = Column(Boolean, default=False)
|
||||
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
|
||||
threeds_authentication_completed = Column(Boolean, default=False)
|
||||
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_threeds_setup_intent_id = Column(String(255), nullable=True)
|
||||
threeds_action_type = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant")
|
||||
|
||||
@@ -82,15 +82,29 @@ class SubscriptionRepository(TenantBaseRepository):
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Check if subscription with this subscription_id already exists to prevent duplicates
|
||||
if subscription_data.get('subscription_id'):
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
tenant_id=subscription_data.get('tenant_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
|
||||
# Create subscription
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
|
||||
logger.info("Subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
|
||||
return subscription
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
@@ -514,7 +528,8 @@ class SubscriptionRepository(TenantBaseRepository):
|
||||
"""Create a subscription not linked to any tenant (for registration flow)"""
|
||||
try:
|
||||
# Validate required data for tenant-independent subscription
|
||||
required_fields = ["user_id", "plan", "subscription_id", "customer_id"]
|
||||
# user_id may not exist during registration, so validate other required fields
|
||||
required_fields = ["plan", "subscription_id", "customer_id"]
|
||||
validation_result = self._validate_tenant_data(subscription_data, required_fields)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
@@ -567,16 +582,41 @@ class SubscriptionRepository(TenantBaseRepository):
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Create tenant-independent subscription
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Tenant-independent subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
user_id=subscription.user_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
# Check if subscription with this subscription_id already exists
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing tenant-independent subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
user_id=subscription_data.get('user_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
else:
|
||||
# Create new subscription, but handle potential duplicate errors
|
||||
try:
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Tenant-independent subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
user_id=subscription.user_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
except DuplicateRecordError:
|
||||
# Another process may have created the subscription between our check and create
|
||||
# Try to get the existing subscription and return it
|
||||
final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if final_subscription:
|
||||
logger.info("Race condition detected: subscription already created by another process",
|
||||
subscription_id=subscription_data['subscription_id'])
|
||||
return final_subscription
|
||||
else:
|
||||
# This shouldn't happen, but re-raise the error if we can't find it
|
||||
raise
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
@@ -700,3 +740,29 @@ class SubscriptionRepository(TenantBaseRepository):
|
||||
logger.error("Failed to cleanup orphaned subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
query = select(Subscription).where(Subscription.customer_id == customer_id)
|
||||
result = await self.session.execute(query)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
logger.debug("Found subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
count=len(subscriptions))
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")
|
||||
|
||||
@@ -11,7 +11,7 @@ import structlog
|
||||
import uuid
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Tenant
|
||||
from app.models.tenants import Tenant, Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
@@ -570,3 +570,39 @@ class TenantRepository(TenantBaseRepository):
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get tenant by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Find tenant by joining with subscriptions table
|
||||
# Tenant doesn't have customer_id directly, so we need to find via subscription
|
||||
query = select(Tenant).join(
|
||||
Subscription, Subscription.tenant_id == Tenant.id
|
||||
).where(Subscription.customer_id == customer_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No tenant found for customer_id",
|
||||
customer_id=customer_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
|
||||
|
||||
@@ -43,7 +43,7 @@ class NetworkAlertsService:
|
||||
return enriched_children
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenants", parent_id=parent_id, error=str(e))
|
||||
logger.error(f"Failed to get child tenants, parent_id={parent_id}, error={str(e)}")
|
||||
raise Exception(f"Failed to get child tenants: {str(e)}")
|
||||
|
||||
async def get_alerts_for_tenant(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
@@ -80,7 +80,7 @@ class NetworkAlertsService:
|
||||
return simulated_alerts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get alerts for tenant", tenant_id=tenant_id, error=str(e))
|
||||
logger.error(f"Failed to get alerts for tenant, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise Exception(f"Failed to get alerts: {str(e)}")
|
||||
|
||||
async def get_network_alerts(self, parent_id: str) -> List[Dict[str, Any]]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
358
services/tenant/app/services/registration_state_service.py
Normal file
358
services/tenant/app/services/registration_state_service.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Registration State Management Service
|
||||
Tracks registration progress and handles state transitions
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
from shared.exceptions.registration_exceptions import (
|
||||
RegistrationStateError,
|
||||
InvalidStateTransitionError
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RegistrationState(Enum):
|
||||
"""Registration process states"""
|
||||
INITIATED = "initiated"
|
||||
PAYMENT_VERIFICATION_PENDING = "payment_verification_pending"
|
||||
PAYMENT_VERIFIED = "payment_verified"
|
||||
SUBSCRIPTION_CREATED = "subscription_created"
|
||||
USER_CREATED = "user_created"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RegistrationStateService:
|
||||
"""
|
||||
Registration State Management Service
|
||||
Tracks and manages registration process state
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize state service"""
|
||||
# In production, this would use a database
|
||||
self.registration_states = {}
|
||||
|
||||
async def create_registration_state(
|
||||
self,
|
||||
email: str,
|
||||
user_data: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Create new registration state
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
user_data: Registration data
|
||||
|
||||
Returns:
|
||||
Registration state ID
|
||||
"""
|
||||
try:
|
||||
state_id = str(uuid4())
|
||||
|
||||
registration_state = {
|
||||
'state_id': state_id,
|
||||
'email': email,
|
||||
'current_state': RegistrationState.INITIATED.value,
|
||||
'created_at': datetime.now().isoformat(),
|
||||
'updated_at': datetime.now().isoformat(),
|
||||
'user_data': user_data,
|
||||
'setup_intent_id': None,
|
||||
'customer_id': None,
|
||||
'subscription_id': None,
|
||||
'error': None
|
||||
}
|
||||
|
||||
self.registration_states[state_id] = registration_state
|
||||
|
||||
logger.info("Registration state created",
|
||||
state_id=state_id,
|
||||
email=email,
|
||||
current_state=RegistrationState.INITIATED.value)
|
||||
|
||||
return state_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create registration state",
|
||||
error=str(e),
|
||||
email=email,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State creation failed: {str(e)}") from e
|
||||
|
||||
async def transition_state(
|
||||
self,
|
||||
state_id: str,
|
||||
new_state: RegistrationState,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Transition registration to new state with validation
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
new_state: New state to transition to
|
||||
context: Additional context data
|
||||
|
||||
Raises:
|
||||
InvalidStateTransitionError: If transition is invalid
|
||||
RegistrationStateError: If transition fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
current_state = self.registration_states[state_id]['current_state']
|
||||
|
||||
# Validate state transition
|
||||
if not self._is_valid_transition(current_state, new_state.value):
|
||||
raise InvalidStateTransitionError(
|
||||
f"Invalid transition from {current_state} to {new_state.value}"
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.registration_states[state_id]['current_state'] = new_state.value
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
# Update context data
|
||||
if context:
|
||||
self.registration_states[state_id].update(context)
|
||||
|
||||
logger.info("Registration state transitioned",
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=new_state.value)
|
||||
|
||||
except InvalidStateTransitionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("State transition failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=new_state.value,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State transition failed: {str(e)}") from e
|
||||
|
||||
async def update_state_context(
|
||||
self,
|
||||
state_id: str,
|
||||
context: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update state context data
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
context: Context data to update
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If update fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
self.registration_states[state_id].update(context)
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.debug("Registration state context updated",
|
||||
state_id=state_id,
|
||||
context_keys=list(context.keys()))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("State context update failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State context update failed: {str(e)}") from e
|
||||
|
||||
async def get_registration_state(
|
||||
self,
|
||||
state_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get registration state by ID
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
|
||||
Returns:
|
||||
Registration state data
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If state not found
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
return self.registration_states[state_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get registration state",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State retrieval failed: {str(e)}") from e
|
||||
|
||||
async def rollback_state(
|
||||
self,
|
||||
state_id: str,
|
||||
target_state: RegistrationState
|
||||
) -> None:
|
||||
"""
|
||||
Rollback registration to previous state
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
target_state: State to rollback to
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If rollback fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
current_state = self.registration_states[state_id]['current_state']
|
||||
|
||||
# Only allow rollback to earlier states
|
||||
if not self._can_rollback(current_state, target_state.value):
|
||||
raise InvalidStateTransitionError(
|
||||
f"Cannot rollback from {current_state} to {target_state.value}"
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.registration_states[state_id]['current_state'] = target_state.value
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
self.registration_states[state_id]['error'] = "Registration rolled back"
|
||||
|
||||
logger.warning("Registration state rolled back",
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=target_state.value)
|
||||
|
||||
except InvalidStateTransitionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("State rollback failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=target_state.value,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State rollback failed: {str(e)}") from e
|
||||
|
||||
async def mark_registration_failed(
|
||||
self,
|
||||
state_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""
|
||||
Mark registration as failed
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
error: Error message
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If operation fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
self.registration_states[state_id]['current_state'] = RegistrationState.FAILED.value
|
||||
self.registration_states[state_id]['error'] = error
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.error("Registration marked as failed",
|
||||
state_id=state_id,
|
||||
error=error)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark registration as failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"Mark failed operation failed: {str(e)}") from e
|
||||
|
||||
def _is_valid_transition(self, current_state: str, new_state: str) -> bool:
|
||||
"""
|
||||
Validate state transition
|
||||
|
||||
Args:
|
||||
current_state: Current state
|
||||
new_state: New state
|
||||
|
||||
Returns:
|
||||
True if transition is valid
|
||||
"""
|
||||
# Define valid state transitions
|
||||
valid_transitions = {
|
||||
RegistrationState.INITIATED.value: [
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value: [
|
||||
RegistrationState.PAYMENT_VERIFIED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.PAYMENT_VERIFIED.value: [
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value: [
|
||||
RegistrationState.USER_CREATED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.USER_CREATED.value: [
|
||||
RegistrationState.COMPLETED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.COMPLETED.value: [],
|
||||
RegistrationState.FAILED.value: []
|
||||
}
|
||||
|
||||
return new_state in valid_transitions.get(current_state, [])
|
||||
|
||||
def _can_rollback(self, current_state: str, target_state: str) -> bool:
|
||||
"""
|
||||
Check if rollback to target state is allowed
|
||||
|
||||
Args:
|
||||
current_state: Current state
|
||||
target_state: Target state for rollback
|
||||
|
||||
Returns:
|
||||
True if rollback is allowed
|
||||
"""
|
||||
# Define state order for rollback validation
|
||||
state_order = [
|
||||
RegistrationState.INITIATED.value,
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
|
||||
RegistrationState.PAYMENT_VERIFIED.value,
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value,
|
||||
RegistrationState.USER_CREATED.value,
|
||||
RegistrationState.COMPLETED.value
|
||||
]
|
||||
|
||||
try:
|
||||
current_index = state_order.index(current_state)
|
||||
target_index = state_order.index(target_state)
|
||||
|
||||
# Can only rollback to earlier states
|
||||
return target_index < current_index
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance for dependency injection
|
||||
registration_state_service = RegistrationStateService()
|
||||
@@ -17,6 +17,8 @@ from app.services.tenant_service import EnhancedTenantService
|
||||
from app.core.config import settings
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.exceptions.payment_exceptions import SubscriptionUpdateFailed
|
||||
from shared.exceptions.subscription_exceptions import SubscriptionNotFound
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -69,7 +71,10 @@ class SubscriptionOrchestrationService:
|
||||
logger.info("Creating customer in payment provider",
|
||||
tenant_id=tenant_id, email=user_data.get('email'))
|
||||
|
||||
customer = await self.payment_service.create_customer(user_data)
|
||||
email = user_data.get('email')
|
||||
name = f"{user_data.get('first_name', '')} {user_data.get('last_name', '')}".strip()
|
||||
metadata = None
|
||||
customer = await self.payment_service.create_customer(email, name, metadata)
|
||||
logger.info("Customer created successfully",
|
||||
customer_id=customer.id, tenant_id=tenant_id)
|
||||
|
||||
@@ -106,9 +111,12 @@ class SubscriptionOrchestrationService:
|
||||
plan_id=plan_id,
|
||||
trial_period_days=trial_period_days)
|
||||
|
||||
stripe_subscription = await self.payment_service.create_payment_subscription(
|
||||
# Get the Stripe price ID for this plan
|
||||
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
|
||||
|
||||
stripe_subscription = await self.payment_service.create_subscription_with_verified_payment(
|
||||
customer.id,
|
||||
plan_id,
|
||||
price_id,
|
||||
payment_method_id,
|
||||
trial_period_days if trial_period_days > 0 else None,
|
||||
billing_interval
|
||||
@@ -232,7 +240,10 @@ class SubscriptionOrchestrationService:
|
||||
user_id=user_data.get('user_id'),
|
||||
email=user_data.get('email'))
|
||||
|
||||
customer = await self.payment_service.create_customer(user_data)
|
||||
email = user_data.get('email')
|
||||
name = f"{user_data.get('first_name', '')} {user_data.get('last_name', '')}".strip()
|
||||
metadata = None
|
||||
customer = await self.payment_service.create_customer(email, name, metadata)
|
||||
logger.info("Customer created successfully",
|
||||
customer_id=customer.id,
|
||||
user_id=user_data.get('user_id'))
|
||||
@@ -271,9 +282,12 @@ class SubscriptionOrchestrationService:
|
||||
plan_id=plan_id,
|
||||
trial_period_days=trial_period_days)
|
||||
|
||||
subscription_result = await self.payment_service.create_payment_subscription(
|
||||
# Get the Stripe price ID for this plan
|
||||
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
|
||||
|
||||
subscription_result = await self.payment_service.create_subscription_with_verified_payment(
|
||||
customer.id,
|
||||
plan_id,
|
||||
price_id,
|
||||
payment_method_id,
|
||||
trial_period_days if trial_period_days > 0 else None,
|
||||
billing_interval
|
||||
@@ -302,9 +316,25 @@ class SubscriptionOrchestrationService:
|
||||
}
|
||||
|
||||
# Extract subscription object from result
|
||||
# Result can be either a dict with 'subscription' key or the subscription object directly
|
||||
if isinstance(subscription_result, dict) and 'subscription' in subscription_result:
|
||||
stripe_subscription = subscription_result['subscription']
|
||||
# Result can be either:
|
||||
# 1. A dict with 'subscription' key containing an object
|
||||
# 2. A dict with subscription fields directly (subscription_id, status, etc.)
|
||||
# 3. A subscription object directly
|
||||
if isinstance(subscription_result, dict):
|
||||
if 'subscription' in subscription_result:
|
||||
stripe_subscription = subscription_result['subscription']
|
||||
elif 'subscription_id' in subscription_result:
|
||||
# Create a simple object-like wrapper for dict results
|
||||
class SubscriptionWrapper:
|
||||
def __init__(self, data: dict):
|
||||
self.id = data.get('subscription_id')
|
||||
self.status = data.get('status')
|
||||
self.current_period_start = data.get('current_period_start')
|
||||
self.current_period_end = data.get('current_period_end')
|
||||
self.customer = data.get('customer_id')
|
||||
stripe_subscription = SubscriptionWrapper(subscription_result)
|
||||
else:
|
||||
stripe_subscription = subscription_result
|
||||
else:
|
||||
stripe_subscription = subscription_result
|
||||
|
||||
@@ -980,7 +1010,7 @@ class SubscriptionOrchestrationService:
|
||||
status=status)
|
||||
|
||||
# Find tenant by subscription
|
||||
subscription = await self.subscription_service.get_subscription_by_stripe_id(subscription_id)
|
||||
subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
|
||||
|
||||
if subscription:
|
||||
# Update subscription status
|
||||
@@ -1014,7 +1044,7 @@ class SubscriptionOrchestrationService:
|
||||
subscription_id=subscription_id)
|
||||
|
||||
# Find and update subscription
|
||||
subscription = await self.subscription_service.get_subscription_by_stripe_id(subscription_id)
|
||||
subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
|
||||
|
||||
if subscription:
|
||||
# Cancel subscription in our system
|
||||
@@ -1618,16 +1648,14 @@ class SubscriptionOrchestrationService:
|
||||
email=user_data.get('email'))
|
||||
|
||||
# Create customer without user_id metadata
|
||||
customer_data = {
|
||||
'email': user_data.get('email'),
|
||||
'name': user_data.get('full_name'),
|
||||
'metadata': {
|
||||
'registration_flow': 'pre_user_creation',
|
||||
'timestamp': datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
email = user_data.get('email')
|
||||
name = user_data.get('full_name')
|
||||
metadata = {
|
||||
'registration_flow': 'pre_user_creation',
|
||||
'timestamp': datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
customer = await self.payment_service.create_customer(customer_data)
|
||||
customer = await self.payment_service.create_customer(email, name, metadata)
|
||||
logger.info("Payment customer created for registration",
|
||||
customer_id=customer.id,
|
||||
email=user_data.get('email'))
|
||||
@@ -1665,9 +1693,12 @@ class SubscriptionOrchestrationService:
|
||||
plan_id=plan_id,
|
||||
payment_method_id=payment_method_id)
|
||||
|
||||
subscription_result = await self.payment_service.create_payment_subscription(
|
||||
# Get the Stripe price ID for this plan
|
||||
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
|
||||
|
||||
subscription_result = await self.payment_service.create_subscription_with_verified_payment(
|
||||
customer.id,
|
||||
plan_id,
|
||||
price_id,
|
||||
payment_method_id,
|
||||
trial_period_days if trial_period_days > 0 else None,
|
||||
billing_interval
|
||||
@@ -1678,14 +1709,19 @@ class SubscriptionOrchestrationService:
|
||||
logger.info("Registration payment setup requires SetupIntent confirmation",
|
||||
customer_id=customer.id,
|
||||
action_type=subscription_result.get('action_type'),
|
||||
setup_intent_id=subscription_result.get('setup_intent_id'))
|
||||
setup_intent_id=subscription_result.get('setup_intent_id'),
|
||||
subscription_id=subscription_result.get('subscription_id'))
|
||||
|
||||
# Return the SetupIntent data for frontend to handle 3DS
|
||||
# Note: subscription_id is included because for trial subscriptions,
|
||||
# the subscription is already created in 'trialing' status even though
|
||||
# the SetupIntent requires 3DS verification for future payments
|
||||
return {
|
||||
"requires_action": True,
|
||||
"action_type": subscription_result.get('action_type'),
|
||||
"action_type": subscription_result.get('action_type') or 'use_stripe_sdk',
|
||||
"client_secret": subscription_result.get('client_secret'),
|
||||
"setup_intent_id": subscription_result.get('setup_intent_id'),
|
||||
"subscription_id": subscription_result.get('subscription_id'),
|
||||
"customer_id": customer.id,
|
||||
"payment_customer_id": customer.id,
|
||||
"plan_id": plan_id,
|
||||
@@ -1764,3 +1800,154 @@ class SubscriptionOrchestrationService:
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
async def validate_plan_upgrade(
|
||||
self,
|
||||
tenant_id: str,
|
||||
new_plan: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate if a tenant can upgrade to a new plan
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
new_plan: New plan to validate upgrade to
|
||||
|
||||
Returns:
|
||||
Dictionary with validation result
|
||||
"""
|
||||
try:
|
||||
logger.info("Validating plan upgrade",
|
||||
tenant_id=tenant_id,
|
||||
new_plan=new_plan)
|
||||
|
||||
# Delegate to subscription service for validation
|
||||
can_upgrade = await self.subscription_service.validate_subscription_change(
|
||||
tenant_id,
|
||||
new_plan
|
||||
)
|
||||
|
||||
result = {
|
||||
"can_upgrade": can_upgrade,
|
||||
"tenant_id": tenant_id,
|
||||
"current_plan": None, # Would need to fetch current plan if needed
|
||||
"new_plan": new_plan
|
||||
}
|
||||
|
||||
if not can_upgrade:
|
||||
result["reason"] = "Subscription change not allowed based on current status"
|
||||
|
||||
logger.info("Plan upgrade validation completed",
|
||||
tenant_id=tenant_id,
|
||||
can_upgrade=can_upgrade)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Plan upgrade validation failed",
|
||||
tenant_id=tenant_id,
|
||||
new_plan=new_plan,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise DatabaseError(f"Failed to validate plan upgrade: {str(e)}")
|
||||
|
||||
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get all subscriptions for a given customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
return await self.subscription_service.get_subscriptions_by_customer_id(customer_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscriptions by customer ID",
|
||||
customer_id=customer_id,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_with_verified_payment(
|
||||
self,
|
||||
subscription_id: str,
|
||||
customer_id: str,
|
||||
payment_method_id: str,
|
||||
trial_period_days: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update an existing subscription with a verified payment method
|
||||
|
||||
This is used when we already have a trial subscription and just need to
|
||||
attach the verified payment method to it.
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
customer_id: Stripe customer ID
|
||||
payment_method_id: Verified payment method ID
|
||||
trial_period_days: Optional trial period (for validation)
|
||||
|
||||
Returns:
|
||||
Dictionary with updated subscription details
|
||||
"""
|
||||
try:
|
||||
logger.info("Updating existing subscription with verified payment method",
|
||||
subscription_id=subscription_id,
|
||||
customer_id=customer_id,
|
||||
payment_method_id=payment_method_id)
|
||||
|
||||
# First, verify the subscription exists and get its current status
|
||||
existing_subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
|
||||
|
||||
if not existing_subscription:
|
||||
raise SubscriptionNotFound(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Update the subscription in Stripe with the verified payment method
|
||||
stripe_subscription = await self.payment_service.update_subscription_payment_method(
|
||||
subscription_id,
|
||||
payment_method_id
|
||||
)
|
||||
|
||||
# Update our local subscription record
|
||||
await self.subscription_service.update_subscription_status(
|
||||
existing_subscription.tenant_id,
|
||||
stripe_subscription.status,
|
||||
{
|
||||
'current_period_start': datetime.fromtimestamp(stripe_subscription.current_period_start),
|
||||
'current_period_end': datetime.fromtimestamp(stripe_subscription.current_period_end)
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock subscription object-like dict for compatibility
|
||||
class SubscriptionResult:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.id = data.get('subscription_id')
|
||||
self.status = data.get('status')
|
||||
self.current_period_start = data.get('current_period_start')
|
||||
self.current_period_end = data.get('current_period_end')
|
||||
self.customer = data.get('customer_id')
|
||||
|
||||
return {
|
||||
'subscription': SubscriptionResult({
|
||||
'subscription_id': stripe_subscription.id,
|
||||
'status': stripe_subscription.status,
|
||||
'current_period_start': stripe_subscription.current_period_start,
|
||||
'current_period_end': stripe_subscription.current_period_end,
|
||||
'customer_id': customer_id
|
||||
}),
|
||||
'verification': {
|
||||
'verified': True,
|
||||
'customer_id': customer_id,
|
||||
'payment_method_id': payment_method_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription with verified payment",
|
||||
subscription_id=subscription_id,
|
||||
customer_id=customer_id,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise SubscriptionUpdateFailed(f"Failed to update subscription: {str(e)}")
|
||||
|
||||
@@ -68,28 +68,46 @@ class SubscriptionService:
|
||||
'tenant_id': str(tenant_id),
|
||||
'subscription_id': subscription_id,
|
||||
'customer_id': customer_id,
|
||||
'plan_id': plan,
|
||||
'plan': plan,
|
||||
'status': status,
|
||||
'created_at': datetime.now(timezone.utc),
|
||||
'trial_period_days': trial_period_days,
|
||||
'billing_cycle': billing_interval
|
||||
}
|
||||
|
||||
created_subscription = await self.subscription_repo.create(subscription_data)
|
||||
# Add trial-related data if applicable
|
||||
if trial_period_days and trial_period_days > 0:
|
||||
from datetime import timedelta
|
||||
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
|
||||
subscription_data['trial_ends_at'] = trial_ends_at
|
||||
|
||||
logger.info("subscription_record_created",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
plan=plan)
|
||||
# Check if subscription with this subscription_id already exists to prevent duplicates
|
||||
existing_subscription = await self.subscription_repo.get_by_provider_id(subscription_id)
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.subscription_repo.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
return created_subscription
|
||||
logger.info("Existing subscription updated",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
plan=plan)
|
||||
|
||||
return updated_subscription
|
||||
else:
|
||||
# Create new subscription
|
||||
created_subscription = await self.subscription_repo.create(subscription_data)
|
||||
|
||||
logger.info("subscription_record_created",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
plan=plan)
|
||||
|
||||
return created_subscription
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("create_subscription_record_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"create_subscription_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("create_subscription_record_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"create_subscription_record_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to create subscription record: {str(e)}")
|
||||
|
||||
async def update_subscription_status(
|
||||
@@ -126,15 +144,15 @@ class SubscriptionService:
|
||||
|
||||
# Include Stripe data if provided
|
||||
if stripe_data:
|
||||
if 'current_period_start' in stripe_data:
|
||||
update_data['current_period_start'] = stripe_data['current_period_start']
|
||||
if 'current_period_end' in stripe_data:
|
||||
update_data['current_period_end'] = stripe_data['current_period_end']
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These would need to be stored separately or handled differently
|
||||
# For now, we'll skip storing these Stripe-specific fields in the local model
|
||||
pass
|
||||
|
||||
# Update status flags based on status value
|
||||
if status == 'active':
|
||||
update_data['is_active'] = True
|
||||
update_data['canceled_at'] = None
|
||||
update_data['cancelled_at'] = None
|
||||
elif status in ['canceled', 'past_due', 'unpaid', 'inactive']:
|
||||
update_data['is_active'] = False
|
||||
elif status == 'pending_cancellation':
|
||||
@@ -153,11 +171,10 @@ class SubscriptionService:
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("update_subscription_status_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"update_subscription_status_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("update_subscription_status_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"update_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
|
||||
|
||||
async def get_subscription_by_tenant_id(
|
||||
@@ -201,6 +218,23 @@ class SubscriptionService:
|
||||
error=str(e), subscription_id=subscription_id)
|
||||
return None
|
||||
|
||||
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
return await self.subscription_repo.get_by_customer_id(customer_id)
|
||||
except Exception as e:
|
||||
logger.error("get_subscriptions_by_customer_id_failed",
|
||||
error=str(e), customer_id=customer_id)
|
||||
return []
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -264,11 +298,10 @@ class SubscriptionService:
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("subscription_cancellation_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"subscription_cancellation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("subscription_cancellation_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"subscription_cancellation_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
|
||||
|
||||
async def reactivate_subscription(
|
||||
@@ -301,7 +334,7 @@ class SubscriptionService:
|
||||
# Update subscription status and plan
|
||||
update_data = {
|
||||
'status': 'active',
|
||||
'plan_id': plan,
|
||||
'plan': plan,
|
||||
'cancelled_at': None,
|
||||
'cancellation_effective_date': None
|
||||
}
|
||||
@@ -329,11 +362,10 @@ class SubscriptionService:
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("subscription_reactivation_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"subscription_reactivation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("subscription_reactivation_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"subscription_reactivation_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
|
||||
|
||||
async def get_subscription_status(
|
||||
@@ -378,7 +410,7 @@ class SubscriptionService:
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("get_subscription_status_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"get_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to get subscription status: {str(e)}")
|
||||
|
||||
async def update_subscription_plan_record(
|
||||
@@ -416,13 +448,14 @@ class SubscriptionService:
|
||||
|
||||
# Update local subscription record
|
||||
update_data = {
|
||||
'plan_id': new_plan,
|
||||
'plan': new_plan,
|
||||
'status': new_status,
|
||||
'current_period_start': new_period_start,
|
||||
'current_period_end': new_period_end,
|
||||
'updated_at': datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These Stripe-specific fields would need to be stored separately
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
@@ -447,11 +480,10 @@ class SubscriptionService:
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("update_subscription_plan_record_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"update_subscription_plan_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("update_subscription_plan_record_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"update_subscription_plan_record_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to update subscription plan record: {str(e)}")
|
||||
|
||||
async def update_billing_cycle_record(
|
||||
@@ -490,11 +522,12 @@ class SubscriptionService:
|
||||
# Update local subscription record
|
||||
update_data = {
|
||||
'status': new_status,
|
||||
'current_period_start': new_period_start,
|
||||
'current_period_end': new_period_end,
|
||||
'updated_at': datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These Stripe-specific fields would need to be stored separately
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
@@ -521,11 +554,10 @@ class SubscriptionService:
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("change_billing_cycle_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
logger.error(f"change_billing_cycle_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("change_billing_cycle_failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error(f"change_billing_cycle_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to change billing cycle: {str(e)}")
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str):
|
||||
@@ -620,13 +652,18 @@ class SubscriptionService:
|
||||
'plan': plan, # Repository expects 'plan', not 'plan_id'
|
||||
'status': status,
|
||||
'created_at': datetime.now(timezone.utc),
|
||||
'trial_period_days': trial_period_days,
|
||||
'billing_cycle': billing_interval,
|
||||
'user_id': user_id,
|
||||
'is_tenant_linked': False,
|
||||
'tenant_linking_status': 'pending'
|
||||
}
|
||||
|
||||
# Add trial-related data if applicable
|
||||
if trial_period_days and trial_period_days > 0:
|
||||
from datetime import timedelta
|
||||
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
|
||||
subscription_data['trial_ends_at'] = trial_ends_at
|
||||
|
||||
created_subscription = await self.subscription_repo.create_tenant_independent_subscription(subscription_data)
|
||||
|
||||
logger.info("tenant_independent_subscription_record_created",
|
||||
@@ -650,7 +687,7 @@ class SubscriptionService:
|
||||
try:
|
||||
return await self.subscription_repo.get_pending_tenant_linking_subscriptions()
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending tenant linking subscriptions", error=str(e))
|
||||
logger.error(f"Failed to get pending tenant linking subscriptions: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
|
||||
@@ -699,5 +736,57 @@ class SubscriptionService:
|
||||
try:
|
||||
return await self.subscription_repo.cleanup_orphaned_subscriptions(days_old)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup orphaned subscriptions", error=str(e))
|
||||
logger.error(f"Failed to cleanup orphaned subscriptions: {str(e)}")
|
||||
raise DatabaseError(f"Failed to cleanup orphaned subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_info(
|
||||
self,
|
||||
subscription_id: str,
|
||||
update_data: Dict[str, Any]
|
||||
) -> Subscription:
|
||||
"""
|
||||
Update subscription-related information (3DS flags, status, etc.)
|
||||
|
||||
This is useful for updating tenant-independent subscriptions during registration.
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID
|
||||
update_data: Dictionary with fields to update
|
||||
|
||||
Returns:
|
||||
Updated Subscription object
|
||||
"""
|
||||
try:
|
||||
# Filter allowed fields
|
||||
allowed_fields = {
|
||||
'plan', 'status', 'is_tenant_linked', 'tenant_linking_status',
|
||||
'threeds_authentication_required', 'threeds_authentication_required_at',
|
||||
'threeds_authentication_completed', 'threeds_authentication_completed_at',
|
||||
'last_threeds_setup_intent_id', 'threeds_action_type'
|
||||
}
|
||||
|
||||
filtered_data = {k: v for k, v in update_data.items() if k in allowed_fields}
|
||||
|
||||
if not filtered_data:
|
||||
logger.warning("No valid subscription info fields provided for update",
|
||||
subscription_id=subscription_id)
|
||||
return await self.subscription_repo.get_by_id(subscription_id)
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(subscription_id, filtered_data)
|
||||
|
||||
if not updated_subscription:
|
||||
raise ValidationError(f"Subscription not found: {subscription_id}")
|
||||
|
||||
logger.info("Subscription info updated",
|
||||
subscription_id=subscription_id,
|
||||
updated_fields=list(filtered_data.keys()))
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription info",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update subscription info: {str(e)}")
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import HTTPException, status
|
||||
from app.repositories import TenantRepository, TenantMemberRepository, SubscriptionRepository
|
||||
from app.models.tenants import Tenant, TenantMember, Subscription
|
||||
from app.schemas.tenants import (
|
||||
BakeryRegistration, TenantResponse, TenantAccessResponse,
|
||||
BakeryRegistration, TenantResponse, TenantAccessResponse,
|
||||
TenantUpdate, TenantMemberResponse
|
||||
)
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
@@ -25,11 +25,11 @@ logger = structlog.get_logger()
|
||||
|
||||
class EnhancedTenantService:
|
||||
"""Enhanced tenant management business logic using repository pattern with dependency injection"""
|
||||
|
||||
|
||||
def __init__(self, database_manager=None, event_publisher=None):
|
||||
self.database_manager = database_manager or create_database_manager()
|
||||
self.event_publisher = event_publisher
|
||||
|
||||
|
||||
async def _init_repositories(self, session):
|
||||
"""Initialize repositories with session"""
|
||||
self.tenant_repo = TenantRepository(Tenant, session)
|
||||
@@ -40,15 +40,15 @@ class EnhancedTenantService:
|
||||
'member': self.member_repo,
|
||||
'subscription': self.subscription_repo
|
||||
}
|
||||
|
||||
|
||||
async def create_bakery(
|
||||
self,
|
||||
bakery_data: BakeryRegistration,
|
||||
self,
|
||||
bakery_data: BakeryRegistration,
|
||||
owner_id: str,
|
||||
session=None
|
||||
) -> TenantResponse:
|
||||
"""Create a new bakery/tenant with enhanced validation and features using repository pattern"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
@@ -115,10 +115,10 @@ class EnhancedTenantService:
|
||||
"longitude": longitude,
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
|
||||
# Create tenant using repository
|
||||
tenant = await tenant_repo.create_tenant(tenant_data)
|
||||
|
||||
|
||||
# Create owner membership
|
||||
membership_data = {
|
||||
"tenant_id": str(tenant.id),
|
||||
@@ -126,7 +126,7 @@ class EnhancedTenantService:
|
||||
"role": "owner",
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
|
||||
owner_membership = await member_repo.create_membership(membership_data)
|
||||
|
||||
# Get subscription plan from user's registration using standardized auth client
|
||||
@@ -164,10 +164,10 @@ class EnhancedTenantService:
|
||||
logger.info("Subscription created",
|
||||
tenant_id=tenant.id,
|
||||
plan=selected_plan)
|
||||
|
||||
|
||||
# Commit the transaction
|
||||
await uow.commit()
|
||||
|
||||
|
||||
# Publish tenant created event
|
||||
if self.event_publisher:
|
||||
try:
|
||||
@@ -245,7 +245,7 @@ class EnhancedTenantService:
|
||||
subdomain=tenant.subdomain)
|
||||
|
||||
return TenantResponse.from_orm(tenant)
|
||||
|
||||
|
||||
except (ValidationError, DuplicateRecordError) as e:
|
||||
logger.error("Validation error creating bakery",
|
||||
name=bakery_data.name,
|
||||
@@ -264,19 +264,19 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create bakery"
|
||||
)
|
||||
|
||||
|
||||
async def verify_user_access(
|
||||
self,
|
||||
user_id: str,
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> TenantAccessResponse:
|
||||
"""Verify if user has access to tenant with enhanced permissions"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
access_info = await self.member_repo.verify_user_access(user_id, tenant_id)
|
||||
|
||||
|
||||
return TenantAccessResponse(
|
||||
has_access=access_info["has_access"],
|
||||
role=access_info["role"],
|
||||
@@ -284,7 +284,7 @@ class EnhancedTenantService:
|
||||
membership_id=access_info.get("membership_id"),
|
||||
joined_at=access_info.get("joined_at")
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error verifying user access",
|
||||
user_id=user_id,
|
||||
@@ -295,10 +295,10 @@ class EnhancedTenantService:
|
||||
role="none",
|
||||
permissions=[]
|
||||
)
|
||||
|
||||
|
||||
async def get_tenant_by_id(self, tenant_id: str) -> Optional[TenantResponse]:
|
||||
"""Get tenant by ID with enhanced data"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
@@ -306,16 +306,16 @@ class EnhancedTenantService:
|
||||
if tenant:
|
||||
return TenantResponse.from_orm(tenant)
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
async def get_tenant_by_subdomain(self, subdomain: str) -> Optional[TenantResponse]:
|
||||
"""Get tenant by subdomain"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
@@ -323,40 +323,40 @@ class EnhancedTenantService:
|
||||
if tenant:
|
||||
return TenantResponse.from_orm(tenant)
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by subdomain",
|
||||
subdomain=subdomain,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
async def get_user_tenants(self, user_id: str) -> List[TenantResponse]:
|
||||
"""Get all tenants accessible by a user (both owned and member tenants)"""
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
|
||||
# Get tenants where user is the owner
|
||||
owned_tenants = await self.tenant_repo.get_tenants_by_owner(user_id)
|
||||
|
||||
|
||||
# Get tenants where user is a member (but not owner)
|
||||
memberships = await self.member_repo.get_user_memberships(user_id, active_only=True)
|
||||
|
||||
|
||||
# Get tenant details for each membership
|
||||
member_tenant_ids = [str(membership.tenant_id) for membership in memberships]
|
||||
member_tenants = []
|
||||
|
||||
|
||||
if member_tenant_ids:
|
||||
# Get tenant details for each membership
|
||||
for tenant_id in member_tenant_ids:
|
||||
tenant = await self.tenant_repo.get_by_id(tenant_id)
|
||||
if tenant:
|
||||
member_tenants.append(tenant)
|
||||
|
||||
|
||||
# Combine and deduplicate (in case user is both owner and member)
|
||||
all_tenants = owned_tenants + member_tenants
|
||||
|
||||
|
||||
# Remove duplicates by tenant ID
|
||||
unique_tenants = []
|
||||
seen_ids = set()
|
||||
@@ -364,7 +364,7 @@ class EnhancedTenantService:
|
||||
if str(tenant.id) not in seen_ids:
|
||||
seen_ids.add(str(tenant.id))
|
||||
unique_tenants.append(tenant)
|
||||
|
||||
|
||||
logger.info(
|
||||
"Retrieved user tenants",
|
||||
user_id=user_id,
|
||||
@@ -372,7 +372,7 @@ class EnhancedTenantService:
|
||||
member_count=len(member_tenants),
|
||||
total_count=len(unique_tenants)
|
||||
)
|
||||
|
||||
|
||||
return [TenantResponse.from_orm(tenant) for tenant in unique_tenants]
|
||||
|
||||
except Exception as e:
|
||||
@@ -510,7 +510,7 @@ class EnhancedTenantService:
|
||||
limit=limit,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
async def search_tenants(
|
||||
self,
|
||||
search_term: str,
|
||||
@@ -520,7 +520,7 @@ class EnhancedTenantService:
|
||||
limit: int = 50
|
||||
) -> List[TenantResponse]:
|
||||
"""Search tenants with filters"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
@@ -528,13 +528,13 @@ class EnhancedTenantService:
|
||||
search_term, business_type, city, skip, limit
|
||||
)
|
||||
return [TenantResponse.from_orm(tenant) for tenant in tenants]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error searching tenants",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
async def update_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -590,17 +590,17 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update tenant"
|
||||
)
|
||||
|
||||
|
||||
async def add_team_member(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
invited_by: str,
|
||||
session: AsyncSession = None
|
||||
) -> TenantMemberResponse:
|
||||
"""Add a team member to tenant with enhanced validation"""
|
||||
|
||||
|
||||
try:
|
||||
# Verify inviter has admin access
|
||||
access = await self.verify_user_access(invited_by, tenant_id)
|
||||
@@ -609,11 +609,11 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to add team members"
|
||||
)
|
||||
|
||||
|
||||
# Create membership using repository
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
|
||||
membership_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
@@ -621,9 +621,9 @@ class EnhancedTenantService:
|
||||
"invited_by": invited_by,
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
|
||||
member = await self.member_repo.create_membership(membership_data)
|
||||
|
||||
|
||||
# Publish member added event
|
||||
if self.event_publisher:
|
||||
try:
|
||||
@@ -640,15 +640,15 @@ class EnhancedTenantService:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish member added event", error=str(e))
|
||||
|
||||
|
||||
logger.info("Team member added successfully",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
invited_by=invited_by)
|
||||
|
||||
|
||||
return TenantMemberResponse.from_orm(member)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except (ValidationError, DuplicateRecordError) as e:
|
||||
@@ -669,7 +669,7 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add team member"
|
||||
)
|
||||
|
||||
|
||||
async def get_team_members(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -697,7 +697,7 @@ class EnhancedTenantService:
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
async def update_member_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -707,7 +707,7 @@ class EnhancedTenantService:
|
||||
session: AsyncSession = None
|
||||
) -> TenantMemberResponse:
|
||||
"""Update team member role"""
|
||||
|
||||
|
||||
try:
|
||||
# Verify updater has admin access
|
||||
access = await self.verify_user_access(updated_by, tenant_id)
|
||||
@@ -716,19 +716,19 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to update member roles"
|
||||
)
|
||||
|
||||
|
||||
updated_member = await self.member_repo.update_member_role(
|
||||
tenant_id, member_user_id, new_role, updated_by
|
||||
)
|
||||
|
||||
|
||||
if not updated_member:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Member not found"
|
||||
)
|
||||
|
||||
|
||||
return TenantMemberResponse.from_orm(updated_member)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except (ValidationError, DuplicateRecordError) as e:
|
||||
@@ -745,7 +745,7 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update member role"
|
||||
)
|
||||
|
||||
|
||||
async def remove_team_member(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -754,7 +754,7 @@ class EnhancedTenantService:
|
||||
session: AsyncSession = None
|
||||
) -> bool:
|
||||
"""Remove team member from tenant"""
|
||||
|
||||
|
||||
try:
|
||||
# Verify remover has admin access
|
||||
access = await self.verify_user_access(removed_by, tenant_id)
|
||||
@@ -763,19 +763,19 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Insufficient permissions to remove team members"
|
||||
)
|
||||
|
||||
|
||||
removed_member = await self.member_repo.deactivate_membership(
|
||||
tenant_id, member_user_id, removed_by
|
||||
)
|
||||
|
||||
|
||||
if not removed_member:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Member not found"
|
||||
)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValidationError as e:
|
||||
@@ -798,10 +798,10 @@ class EnhancedTenantService:
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
|
||||
# Get all user memberships
|
||||
memberships = await self.member_repo.get_user_memberships(user_id, active_only=False)
|
||||
|
||||
|
||||
# Convert to response format
|
||||
result = []
|
||||
for membership in memberships:
|
||||
@@ -814,13 +814,13 @@ class EnhancedTenantService:
|
||||
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None,
|
||||
"invited_by": str(membership.invited_by) if membership.invited_by else None
|
||||
})
|
||||
|
||||
|
||||
logger.info("Retrieved user memberships",
|
||||
user_id=user_id,
|
||||
membership_count=len(result))
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user memberships",
|
||||
user_id=user_id,
|
||||
@@ -829,7 +829,7 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user memberships"
|
||||
)
|
||||
|
||||
|
||||
async def update_model_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -838,7 +838,7 @@ class EnhancedTenantService:
|
||||
last_training_date: datetime = None
|
||||
) -> TenantResponse:
|
||||
"""Update tenant model training status"""
|
||||
|
||||
|
||||
try:
|
||||
# Verify user has access
|
||||
access = await self.verify_user_access(user_id, tenant_id)
|
||||
@@ -847,21 +847,21 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant"
|
||||
)
|
||||
|
||||
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
updated_tenant = await self.tenant_repo.update_tenant_model_status(
|
||||
tenant_id, ml_model_trained, last_training_date
|
||||
)
|
||||
|
||||
|
||||
if not updated_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
|
||||
return TenantResponse.from_orm(updated_tenant)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -872,31 +872,31 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update model status"
|
||||
)
|
||||
|
||||
|
||||
async def get_tenant_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive tenant statistics"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
# Get tenant statistics
|
||||
tenant_stats = await self.tenant_repo.get_tenant_statistics()
|
||||
|
||||
|
||||
# Get subscription statistics
|
||||
subscription_stats = await self.subscription_repo.get_subscription_statistics()
|
||||
|
||||
|
||||
return {
|
||||
"tenants": tenant_stats,
|
||||
"subscriptions": subscription_stats
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant statistics", error=str(e))
|
||||
logger.error(f"Error getting tenant statistics: {str(e)}")
|
||||
return {
|
||||
"tenants": {},
|
||||
"subscriptions": {}
|
||||
}
|
||||
|
||||
|
||||
async def get_tenants_near_location(
|
||||
self,
|
||||
latitude: float,
|
||||
@@ -905,23 +905,23 @@ class EnhancedTenantService:
|
||||
limit: int = 50
|
||||
) -> List[TenantResponse]:
|
||||
"""Get tenants near a geographic location"""
|
||||
|
||||
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
tenants = await self.tenant_repo.get_tenants_by_location(
|
||||
latitude, longitude, radius_km, limit
|
||||
)
|
||||
|
||||
|
||||
return [TenantResponse.from_orm(tenant) for tenant in tenants]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenants by location",
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
async def deactivate_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -929,7 +929,7 @@ class EnhancedTenantService:
|
||||
session: AsyncSession = None
|
||||
) -> bool:
|
||||
"""Deactivate a tenant (admin only)"""
|
||||
|
||||
|
||||
try:
|
||||
# Verify user is owner
|
||||
access = await self.verify_user_access(user_id, tenant_id)
|
||||
@@ -938,29 +938,29 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only tenant owner can deactivate tenant"
|
||||
)
|
||||
|
||||
|
||||
deactivated_tenant = await self.tenant_repo.deactivate_tenant(tenant_id)
|
||||
|
||||
|
||||
if not deactivated_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
|
||||
# Also suspend subscription
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if subscription:
|
||||
await self.subscription_repo.suspend_subscription(
|
||||
str(subscription.id),
|
||||
str(subscription.id),
|
||||
"Tenant deactivated"
|
||||
)
|
||||
|
||||
|
||||
logger.info("Tenant deactivated",
|
||||
tenant_id=tenant_id,
|
||||
deactivated_by=user_id)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -971,7 +971,7 @@ class EnhancedTenantService:
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to deactivate tenant"
|
||||
)
|
||||
|
||||
|
||||
async def activate_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -1378,7 +1378,7 @@ class EnhancedTenantService:
|
||||
# ========================================================================
|
||||
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
|
||||
# ========================================================================
|
||||
|
||||
|
||||
async def link_subscription_to_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@@ -1387,15 +1387,15 @@ class EnhancedTenantService:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Link a pending subscription to a tenant
|
||||
|
||||
|
||||
This completes the registration flow by associating the subscription
|
||||
created during registration with the tenant created during onboarding
|
||||
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to link to
|
||||
subscription_id: Subscription ID to link
|
||||
user_id: User ID performing the linking (for validation)
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with linking results
|
||||
"""
|
||||
@@ -1409,30 +1409,30 @@ class EnhancedTenantService:
|
||||
tenant_repo = uow.register_repository(
|
||||
"tenants", TenantRepository, Tenant
|
||||
)
|
||||
|
||||
|
||||
# Get the subscription
|
||||
subscription = await subscription_repo.get_by_id(subscription_id)
|
||||
|
||||
|
||||
if not subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Subscription not found"
|
||||
)
|
||||
|
||||
|
||||
# Verify subscription is in pending_tenant_linking state
|
||||
if subscription.tenant_linking_status != "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Subscription is not in pending tenant linking state"
|
||||
)
|
||||
|
||||
|
||||
# Verify subscription belongs to this user
|
||||
if subscription.user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Subscription does not belong to this user"
|
||||
)
|
||||
|
||||
|
||||
# Update subscription with tenant_id
|
||||
update_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -1440,36 +1440,41 @@ class EnhancedTenantService:
|
||||
"tenant_linking_status": "completed",
|
||||
"linked_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
|
||||
await subscription_repo.update(subscription_id, update_data)
|
||||
|
||||
# Update tenant with subscription information
|
||||
|
||||
# Update tenant with subscription information including 3DS flags
|
||||
tenant_update = {
|
||||
"customer_id": subscription.customer_id,
|
||||
"subscription_status": subscription.status,
|
||||
"subscription_plan": subscription.plan,
|
||||
"subscription_tier": subscription.plan,
|
||||
"billing_cycle": subscription.billing_cycle,
|
||||
"trial_period_days": subscription.trial_period_days
|
||||
"trial_period_days": subscription.trial_period_days,
|
||||
"threeds_authentication_required": getattr(subscription, 'threeds_authentication_required', False),
|
||||
"threeds_authentication_required_at": getattr(subscription, 'threeds_authentication_required_at', None),
|
||||
"threeds_authentication_completed": getattr(subscription, 'threeds_authentication_completed', False),
|
||||
"threeds_authentication_completed_at": getattr(subscription, 'threeds_authentication_completed_at', None),
|
||||
"last_threeds_setup_intent_id": getattr(subscription, 'last_threeds_setup_intent_id', None),
|
||||
"threeds_action_type": getattr(subscription, 'threeds_action_type', None)
|
||||
}
|
||||
|
||||
await tenant_repo.update_tenant(tenant_id, tenant_update)
|
||||
|
||||
|
||||
await tenant_repo.update(tenant_id, tenant_update)
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
|
||||
logger.info("Subscription successfully linked to tenant",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
user_id=user_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"subscription_id": subscription_id,
|
||||
"status": "linked"
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to link subscription to tenant",
|
||||
error=str(e),
|
||||
@@ -1478,5 +1483,127 @@ class EnhancedTenantService:
|
||||
user_id=user_id)
|
||||
raise
|
||||
|
||||
async def update_tenant_subscription_info(
|
||||
self,
|
||||
tenant_id: str,
|
||||
update_data: Dict[str, Any]
|
||||
) -> TenantResponse:
|
||||
"""
|
||||
Update tenant subscription-related information (plan, status, 3DS flags)
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
update_data: Dictionary with fields to update
|
||||
|
||||
Returns:
|
||||
Updated Tenant object
|
||||
"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Filter allowed fields to prevent accidental overwrites of core tenant data
|
||||
allowed_fields = {
|
||||
'subscription_plan', 'subscription_status', 'subscription_tier',
|
||||
'billing_cycle', 'trial_period_days', 'customer_id',
|
||||
'threeds_authentication_required', 'threeds_authentication_required_at',
|
||||
'threeds_authentication_completed', 'threeds_authentication_completed_at',
|
||||
'last_threeds_setup_intent_id', 'threeds_action_type'
|
||||
}
|
||||
|
||||
filtered_data = {k: v for k, v in update_data.items() if k in allowed_fields}
|
||||
|
||||
if not filtered_data:
|
||||
logger.warning("No valid subscription info fields provided for update",
|
||||
tenant_id=tenant_id)
|
||||
tenant = await self.tenant_repo.get_by_id(tenant_id)
|
||||
return TenantResponse.from_orm(tenant)
|
||||
|
||||
updated_tenant = await self.tenant_repo.update(tenant_id, filtered_data)
|
||||
|
||||
if not updated_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
logger.info("Tenant subscription info updated",
|
||||
tenant_id=tenant_id,
|
||||
updated_fields=list(filtered_data.keys()))
|
||||
|
||||
return TenantResponse.from_orm(updated_tenant)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update tenant subscription info",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update tenant subscription info: {str(e)}"
|
||||
)
|
||||
|
||||
async def get_tenant_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get tenant by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Search for tenant with matching customer_id
|
||||
tenant = await self.tenant_repo.get_by_customer_id(customer_id)
|
||||
|
||||
if tenant:
|
||||
logger.info("Found tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.info("No tenant found for customer_id",
|
||||
customer_id=customer_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Search for subscriptions with matching customer_id
|
||||
subscriptions = await self.subscription_repo.get_by_customer_id(customer_id)
|
||||
|
||||
logger.info("Found subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
count=len(subscriptions))
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
# Legacy compatibility alias
|
||||
TenantService = EnhancedTenantService
|
||||
|
||||
@@ -97,7 +97,7 @@ class TenantSettingsService:
|
||||
|
||||
return settings
|
||||
except Exception as e:
|
||||
logger.error("Failed to get or create tenant settings", tenant_id=tenant_id, error=str(e), exc_info=True)
|
||||
logger.error(f"Failed to get or create tenant settings, tenant_id={tenant_id}, error={str(e)}", exc_info=True)
|
||||
# Re-raise as HTTPException to match the expected behavior
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
||||
Reference in New Issue
Block a user