Add subcription feature 3

This commit is contained in:
Urtzi Alfaro
2026-01-15 20:45:49 +01:00
parent a4c3b7da3f
commit b674708a4c
83 changed files with 9451 additions and 6828 deletions

View File

@@ -211,7 +211,8 @@ async def clone_demo_data(
subscription_data.get('cancellation_effective_date'),
session_time,
"cancellation_effective_date"
)
),
is_tenant_linked=True # Required for check constraint when tenant_id is set
)
db.add(subscription)
@@ -245,7 +246,8 @@ async def clone_demo_data(
"api_access": True,
"priority_support": True
},
next_billing_date=datetime.now(timezone.utc) + timedelta(days=90)
next_billing_date=datetime.now(timezone.utc) + timedelta(days=90),
is_tenant_linked=True # Required for check constraint when tenant_id is set
)
db.add(subscription)
@@ -323,7 +325,8 @@ async def clone_demo_data(
max_users=-1, # Unlimited for demo
max_locations=max_locations,
max_products=-1, # Unlimited for demo
features={}
features={},
is_tenant_linked=True # Required for check constraint when tenant_id is set
)
db.add(demo_subscription)
@@ -699,7 +702,8 @@ async def create_child_outlet(
max_users=10, # Demo limits
max_locations=1, # Single location for outlet
max_products=200,
features={}
features={},
is_tenant_linked=True # Required for check constraint when tenant_id is set
)
db.add(child_subscription)

File diff suppressed because it is too large Load Diff

View File

@@ -1122,6 +1122,88 @@ async def upgrade_subscription_plan(
detail="Failed to upgrade subscription plan"
)
# ============================================================================
# REGISTRATION ORCHESTRATION (SECURE 3DS FLOW)
# ============================================================================
@router.post("/registration-payment-setup")
async def registration_payment_setup(
user_data: Dict[str, Any],
payment_service: PaymentService = Depends(get_payment_service)
):
"""
Orchestrate initial payment setup for a new registration.
This creates the customer and SetupIntent for 3DS verification.
"""
try:
logger.info("Orchestrating registration payment setup",
email=user_data.get('email'))
result = await payment_service.complete_registration_payment_flow(user_data)
return result
except Exception as e:
logger.error("Registration payment setup orchestration failed",
email=user_data.get('email'),
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration payment setup failed: {str(e)}"
)
@router.post("/verify-and-complete-registration")
async def verify_and_complete_registration(
registration_data: Dict[str, Any],
payment_service: PaymentService = Depends(get_payment_service),
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
"""
Final step: Verify SetupIntent and create subscription + tenant.
Called after frontend confirms 3DS with Stripe.
"""
try:
setup_intent_id = registration_data.get('setup_intent_id')
user_data = registration_data.get('user_data', {})
logger.info("Verifying and completing registration",
email=user_data.get('email'),
setup_intent_id=setup_intent_id)
# 1. Complete subscription creation after verification
payment_result = await payment_service.complete_subscription_after_verification(
setup_intent_id, user_data
)
# 2. Create the bakery/tenant record
# Note: In a real flow, we'd use the payment result (customer_id, subscription_id)
# to properly link the new tenant.
bakery_registration = BakeryRegistration(
name=user_data.get('name', f"{user_data.get('full_name')}'s Bakery"),
subdomain=user_data.get('subdomain', user_data.get('email').split('@')[0]),
business_type=user_data.get('business_type', 'bakery'),
link_existing_subscription=True,
subscription_id=payment_result['subscription_id']
)
# We need the user_id from the auth service call
user_id = user_data.get('user_id')
if not user_id:
raise HTTPException(status_code=400, detail="Missing user_id in registration data")
tenant_result = await tenant_service.create_bakery(bakery_registration, user_id)
return {
"success": True,
"tenant": tenant_result,
"payment": payment_result
}
except Exception as e:
logger.error("Verify and complete registration failed",
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Registration completion failed: {str(e)}"
)
# ============================================================================
# PAYMENT OPERATIONS
# ============================================================================
@@ -1244,7 +1326,7 @@ async def register_with_subscription(
**result
}
except Exception as e:
logger.error("Failed to register with subscription", error=str(e))
logger.error(f"Failed to register with subscription: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to register with subscription"

View File

@@ -88,7 +88,10 @@ async def stripe_webhook(
raise
except Exception as e:
logger.error("Error processing Stripe webhook", error=str(e), exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Webhook processing error"
)
# Return 200 OK even on processing errors to prevent Stripe retries
# Only return 4xx for signature verification failures
return {
"success": False,
"error": "Webhook processing error",
"details": str(e)
}

View File

@@ -67,6 +67,14 @@ class Tenant(Base):
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# 3D Secure (3DS) tracking
threeds_authentication_required = Column(Boolean, default=False)
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
threeds_authentication_completed = Column(Boolean, default=False)
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
last_threeds_setup_intent_id = Column(String(255), nullable=True)
threeds_action_type = Column(String(100), nullable=True)
# Relationships - only within tenant service
members = relationship("TenantMember", back_populates="tenant", cascade="all, delete-orphan")
subscriptions = relationship("Subscription", back_populates="tenant", cascade="all, delete-orphan")
@@ -187,6 +195,14 @@ class Subscription(Base):
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# 3D Secure (3DS) tracking
threeds_authentication_required = Column(Boolean, default=False)
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
threeds_authentication_completed = Column(Boolean, default=False)
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
last_threeds_setup_intent_id = Column(String(255), nullable=True)
threeds_action_type = Column(String(100), nullable=True)
# Relationships
tenant = relationship("Tenant")

View File

@@ -82,15 +82,29 @@ class SubscriptionRepository(TenantBaseRepository):
else:
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
# Check if subscription with this subscription_id already exists to prevent duplicates
if subscription_data.get('subscription_id'):
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if existing_subscription:
# Update the existing subscription instead of creating a duplicate
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
logger.info("Existing subscription updated",
subscription_id=subscription_data['subscription_id'],
tenant_id=subscription_data.get('tenant_id'),
plan=subscription_data.get('plan'))
return updated_subscription
# Create subscription
subscription = await self.create(subscription_data)
logger.info("Subscription created successfully",
subscription_id=subscription.id,
tenant_id=subscription.tenant_id,
plan=subscription.plan,
monthly_price=subscription.monthly_price)
return subscription
except (ValidationError, DuplicateRecordError):
@@ -514,7 +528,8 @@ class SubscriptionRepository(TenantBaseRepository):
"""Create a subscription not linked to any tenant (for registration flow)"""
try:
# Validate required data for tenant-independent subscription
required_fields = ["user_id", "plan", "subscription_id", "customer_id"]
# user_id may not exist during registration, so validate other required fields
required_fields = ["plan", "subscription_id", "customer_id"]
validation_result = self._validate_tenant_data(subscription_data, required_fields)
if not validation_result["is_valid"]:
@@ -567,16 +582,41 @@ class SubscriptionRepository(TenantBaseRepository):
else:
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
# Create tenant-independent subscription
subscription = await self.create(subscription_data)
logger.info("Tenant-independent subscription created successfully",
subscription_id=subscription.id,
user_id=subscription.user_id,
plan=subscription.plan,
monthly_price=subscription.monthly_price)
return subscription
# Check if subscription with this subscription_id already exists
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if existing_subscription:
# Update the existing subscription instead of creating a duplicate
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
logger.info("Existing tenant-independent subscription updated",
subscription_id=subscription_data['subscription_id'],
user_id=subscription_data.get('user_id'),
plan=subscription_data.get('plan'))
return updated_subscription
else:
# Create new subscription, but handle potential duplicate errors
try:
subscription = await self.create(subscription_data)
logger.info("Tenant-independent subscription created successfully",
subscription_id=subscription.id,
user_id=subscription.user_id,
plan=subscription.plan,
monthly_price=subscription.monthly_price)
return subscription
except DuplicateRecordError:
# Another process may have created the subscription between our check and create
# Try to get the existing subscription and return it
final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if final_subscription:
logger.info("Race condition detected: subscription already created by another process",
subscription_id=subscription_data['subscription_id'])
return final_subscription
else:
# This shouldn't happen, but re-raise the error if we can't find it
raise
except (ValidationError, DuplicateRecordError):
raise
@@ -700,3 +740,29 @@ class SubscriptionRepository(TenantBaseRepository):
logger.error("Failed to cleanup orphaned subscriptions",
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get subscriptions by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
query = select(Subscription).where(Subscription.customer_id == customer_id)
result = await self.session.execute(query)
subscriptions = result.scalars().all()
logger.debug("Found subscriptions by customer_id",
customer_id=customer_id,
count=len(subscriptions))
return subscriptions
except Exception as e:
logger.error("Error getting subscriptions by customer_id",
customer_id=customer_id,
error=str(e))
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")

View File

@@ -11,7 +11,7 @@ import structlog
import uuid
from .base import TenantBaseRepository
from app.models.tenants import Tenant
from app.models.tenants import Tenant, Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
@@ -570,3 +570,39 @@ class TenantRepository(TenantBaseRepository):
session_id=session_id,
error=str(e))
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
"""
Get tenant by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
Tenant object if found, None otherwise
"""
try:
# Find tenant by joining with subscriptions table
# Tenant doesn't have customer_id directly, so we need to find via subscription
query = select(Tenant).join(
Subscription, Subscription.tenant_id == Tenant.id
).where(Subscription.customer_id == customer_id)
result = await self.session.execute(query)
tenant = result.scalar_one_or_none()
if tenant:
logger.debug("Found tenant by customer_id",
customer_id=customer_id,
tenant_id=str(tenant.id))
return tenant
else:
logger.debug("No tenant found for customer_id",
customer_id=customer_id)
return None
except Exception as e:
logger.error("Error getting tenant by customer_id",
customer_id=customer_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")

View File

@@ -43,7 +43,7 @@ class NetworkAlertsService:
return enriched_children
except Exception as e:
logger.error("Failed to get child tenants", parent_id=parent_id, error=str(e))
logger.error(f"Failed to get child tenants, parent_id={parent_id}, error={str(e)}")
raise Exception(f"Failed to get child tenants: {str(e)}")
async def get_alerts_for_tenant(self, tenant_id: str) -> List[Dict[str, Any]]:
@@ -80,7 +80,7 @@ class NetworkAlertsService:
return simulated_alerts
except Exception as e:
logger.error("Failed to get alerts for tenant", tenant_id=tenant_id, error=str(e))
logger.error(f"Failed to get alerts for tenant, tenant_id={tenant_id}, error={str(e)}")
raise Exception(f"Failed to get alerts: {str(e)}")
async def get_network_alerts(self, parent_id: str) -> List[Dict[str, Any]]:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,358 @@
"""
Registration State Management Service
Tracks registration progress and handles state transitions
"""
import structlog
from typing import Dict, Any, Optional
from datetime import datetime
from enum import Enum
from uuid import uuid4
from shared.exceptions.registration_exceptions import (
RegistrationStateError,
InvalidStateTransitionError
)
# Configure logging
logger = structlog.get_logger()
class RegistrationState(Enum):
"""Registration process states"""
INITIATED = "initiated"
PAYMENT_VERIFICATION_PENDING = "payment_verification_pending"
PAYMENT_VERIFIED = "payment_verified"
SUBSCRIPTION_CREATED = "subscription_created"
USER_CREATED = "user_created"
COMPLETED = "completed"
FAILED = "failed"
class RegistrationStateService:
"""
Registration State Management Service
Tracks and manages registration process state
"""
def __init__(self):
"""Initialize state service"""
# In production, this would use a database
self.registration_states = {}
async def create_registration_state(
self,
email: str,
user_data: Dict[str, Any]
) -> str:
"""
Create new registration state
Args:
email: User email
user_data: Registration data
Returns:
Registration state ID
"""
try:
state_id = str(uuid4())
registration_state = {
'state_id': state_id,
'email': email,
'current_state': RegistrationState.INITIATED.value,
'created_at': datetime.now().isoformat(),
'updated_at': datetime.now().isoformat(),
'user_data': user_data,
'setup_intent_id': None,
'customer_id': None,
'subscription_id': None,
'error': None
}
self.registration_states[state_id] = registration_state
logger.info("Registration state created",
state_id=state_id,
email=email,
current_state=RegistrationState.INITIATED.value)
return state_id
except Exception as e:
logger.error("Failed to create registration state",
error=str(e),
email=email,
exc_info=True)
raise RegistrationStateError(f"State creation failed: {str(e)}") from e
async def transition_state(
self,
state_id: str,
new_state: RegistrationState,
context: Optional[Dict[str, Any]] = None
) -> None:
"""
Transition registration to new state with validation
Args:
state_id: Registration state ID
new_state: New state to transition to
context: Additional context data
Raises:
InvalidStateTransitionError: If transition is invalid
RegistrationStateError: If transition fails
"""
try:
if state_id not in self.registration_states:
raise RegistrationStateError(f"Registration state {state_id} not found")
current_state = self.registration_states[state_id]['current_state']
# Validate state transition
if not self._is_valid_transition(current_state, new_state.value):
raise InvalidStateTransitionError(
f"Invalid transition from {current_state} to {new_state.value}"
)
# Update state
self.registration_states[state_id]['current_state'] = new_state.value
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
# Update context data
if context:
self.registration_states[state_id].update(context)
logger.info("Registration state transitioned",
state_id=state_id,
from_state=current_state,
to_state=new_state.value)
except InvalidStateTransitionError:
raise
except Exception as e:
logger.error("State transition failed",
error=str(e),
state_id=state_id,
from_state=current_state,
to_state=new_state.value,
exc_info=True)
raise RegistrationStateError(f"State transition failed: {str(e)}") from e
async def update_state_context(
self,
state_id: str,
context: Dict[str, Any]
) -> None:
"""
Update state context data
Args:
state_id: Registration state ID
context: Context data to update
Raises:
RegistrationStateError: If update fails
"""
try:
if state_id not in self.registration_states:
raise RegistrationStateError(f"Registration state {state_id} not found")
self.registration_states[state_id].update(context)
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
logger.debug("Registration state context updated",
state_id=state_id,
context_keys=list(context.keys()))
except Exception as e:
logger.error("State context update failed",
error=str(e),
state_id=state_id,
exc_info=True)
raise RegistrationStateError(f"State context update failed: {str(e)}") from e
async def get_registration_state(
self,
state_id: str
) -> Dict[str, Any]:
"""
Get registration state by ID
Args:
state_id: Registration state ID
Returns:
Registration state data
Raises:
RegistrationStateError: If state not found
"""
try:
if state_id not in self.registration_states:
raise RegistrationStateError(f"Registration state {state_id} not found")
return self.registration_states[state_id]
except Exception as e:
logger.error("Failed to get registration state",
error=str(e),
state_id=state_id,
exc_info=True)
raise RegistrationStateError(f"State retrieval failed: {str(e)}") from e
async def rollback_state(
self,
state_id: str,
target_state: RegistrationState
) -> None:
"""
Rollback registration to previous state
Args:
state_id: Registration state ID
target_state: State to rollback to
Raises:
RegistrationStateError: If rollback fails
"""
try:
if state_id not in self.registration_states:
raise RegistrationStateError(f"Registration state {state_id} not found")
current_state = self.registration_states[state_id]['current_state']
# Only allow rollback to earlier states
if not self._can_rollback(current_state, target_state.value):
raise InvalidStateTransitionError(
f"Cannot rollback from {current_state} to {target_state.value}"
)
# Update state
self.registration_states[state_id]['current_state'] = target_state.value
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
self.registration_states[state_id]['error'] = "Registration rolled back"
logger.warning("Registration state rolled back",
state_id=state_id,
from_state=current_state,
to_state=target_state.value)
except InvalidStateTransitionError:
raise
except Exception as e:
logger.error("State rollback failed",
error=str(e),
state_id=state_id,
from_state=current_state,
to_state=target_state.value,
exc_info=True)
raise RegistrationStateError(f"State rollback failed: {str(e)}") from e
async def mark_registration_failed(
self,
state_id: str,
error: str
) -> None:
"""
Mark registration as failed
Args:
state_id: Registration state ID
error: Error message
Raises:
RegistrationStateError: If operation fails
"""
try:
if state_id not in self.registration_states:
raise RegistrationStateError(f"Registration state {state_id} not found")
self.registration_states[state_id]['current_state'] = RegistrationState.FAILED.value
self.registration_states[state_id]['error'] = error
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
logger.error("Registration marked as failed",
state_id=state_id,
error=error)
except Exception as e:
logger.error("Failed to mark registration as failed",
error=str(e),
state_id=state_id,
exc_info=True)
raise RegistrationStateError(f"Mark failed operation failed: {str(e)}") from e
def _is_valid_transition(self, current_state: str, new_state: str) -> bool:
"""
Validate state transition
Args:
current_state: Current state
new_state: New state
Returns:
True if transition is valid
"""
# Define valid state transitions
valid_transitions = {
RegistrationState.INITIATED.value: [
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
RegistrationState.FAILED.value
],
RegistrationState.PAYMENT_VERIFICATION_PENDING.value: [
RegistrationState.PAYMENT_VERIFIED.value,
RegistrationState.FAILED.value
],
RegistrationState.PAYMENT_VERIFIED.value: [
RegistrationState.SUBSCRIPTION_CREATED.value,
RegistrationState.FAILED.value
],
RegistrationState.SUBSCRIPTION_CREATED.value: [
RegistrationState.USER_CREATED.value,
RegistrationState.FAILED.value
],
RegistrationState.USER_CREATED.value: [
RegistrationState.COMPLETED.value,
RegistrationState.FAILED.value
],
RegistrationState.COMPLETED.value: [],
RegistrationState.FAILED.value: []
}
return new_state in valid_transitions.get(current_state, [])
def _can_rollback(self, current_state: str, target_state: str) -> bool:
"""
Check if rollback to target state is allowed
Args:
current_state: Current state
target_state: Target state for rollback
Returns:
True if rollback is allowed
"""
# Define state order for rollback validation
state_order = [
RegistrationState.INITIATED.value,
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
RegistrationState.PAYMENT_VERIFIED.value,
RegistrationState.SUBSCRIPTION_CREATED.value,
RegistrationState.USER_CREATED.value,
RegistrationState.COMPLETED.value
]
try:
current_index = state_order.index(current_state)
target_index = state_order.index(target_state)
# Can only rollback to earlier states
return target_index < current_index
except ValueError:
return False
# Singleton instance for dependency injection
registration_state_service = RegistrationStateService()

View File

@@ -17,6 +17,8 @@ from app.services.tenant_service import EnhancedTenantService
from app.core.config import settings
from shared.database.exceptions import DatabaseError, ValidationError
from shared.database.base import create_database_manager
from shared.exceptions.payment_exceptions import SubscriptionUpdateFailed
from shared.exceptions.subscription_exceptions import SubscriptionNotFound
logger = structlog.get_logger()
@@ -69,7 +71,10 @@ class SubscriptionOrchestrationService:
logger.info("Creating customer in payment provider",
tenant_id=tenant_id, email=user_data.get('email'))
customer = await self.payment_service.create_customer(user_data)
email = user_data.get('email')
name = f"{user_data.get('first_name', '')} {user_data.get('last_name', '')}".strip()
metadata = None
customer = await self.payment_service.create_customer(email, name, metadata)
logger.info("Customer created successfully",
customer_id=customer.id, tenant_id=tenant_id)
@@ -106,9 +111,12 @@ class SubscriptionOrchestrationService:
plan_id=plan_id,
trial_period_days=trial_period_days)
stripe_subscription = await self.payment_service.create_payment_subscription(
# Get the Stripe price ID for this plan
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
stripe_subscription = await self.payment_service.create_subscription_with_verified_payment(
customer.id,
plan_id,
price_id,
payment_method_id,
trial_period_days if trial_period_days > 0 else None,
billing_interval
@@ -232,7 +240,10 @@ class SubscriptionOrchestrationService:
user_id=user_data.get('user_id'),
email=user_data.get('email'))
customer = await self.payment_service.create_customer(user_data)
email = user_data.get('email')
name = f"{user_data.get('first_name', '')} {user_data.get('last_name', '')}".strip()
metadata = None
customer = await self.payment_service.create_customer(email, name, metadata)
logger.info("Customer created successfully",
customer_id=customer.id,
user_id=user_data.get('user_id'))
@@ -271,9 +282,12 @@ class SubscriptionOrchestrationService:
plan_id=plan_id,
trial_period_days=trial_period_days)
subscription_result = await self.payment_service.create_payment_subscription(
# Get the Stripe price ID for this plan
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
subscription_result = await self.payment_service.create_subscription_with_verified_payment(
customer.id,
plan_id,
price_id,
payment_method_id,
trial_period_days if trial_period_days > 0 else None,
billing_interval
@@ -302,9 +316,25 @@ class SubscriptionOrchestrationService:
}
# Extract subscription object from result
# Result can be either a dict with 'subscription' key or the subscription object directly
if isinstance(subscription_result, dict) and 'subscription' in subscription_result:
stripe_subscription = subscription_result['subscription']
# Result can be either:
# 1. A dict with 'subscription' key containing an object
# 2. A dict with subscription fields directly (subscription_id, status, etc.)
# 3. A subscription object directly
if isinstance(subscription_result, dict):
if 'subscription' in subscription_result:
stripe_subscription = subscription_result['subscription']
elif 'subscription_id' in subscription_result:
# Create a simple object-like wrapper for dict results
class SubscriptionWrapper:
def __init__(self, data: dict):
self.id = data.get('subscription_id')
self.status = data.get('status')
self.current_period_start = data.get('current_period_start')
self.current_period_end = data.get('current_period_end')
self.customer = data.get('customer_id')
stripe_subscription = SubscriptionWrapper(subscription_result)
else:
stripe_subscription = subscription_result
else:
stripe_subscription = subscription_result
@@ -980,7 +1010,7 @@ class SubscriptionOrchestrationService:
status=status)
# Find tenant by subscription
subscription = await self.subscription_service.get_subscription_by_stripe_id(subscription_id)
subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
if subscription:
# Update subscription status
@@ -1014,7 +1044,7 @@ class SubscriptionOrchestrationService:
subscription_id=subscription_id)
# Find and update subscription
subscription = await self.subscription_service.get_subscription_by_stripe_id(subscription_id)
subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
if subscription:
# Cancel subscription in our system
@@ -1618,16 +1648,14 @@ class SubscriptionOrchestrationService:
email=user_data.get('email'))
# Create customer without user_id metadata
customer_data = {
'email': user_data.get('email'),
'name': user_data.get('full_name'),
'metadata': {
'registration_flow': 'pre_user_creation',
'timestamp': datetime.now(timezone.utc).isoformat()
}
email = user_data.get('email')
name = user_data.get('full_name')
metadata = {
'registration_flow': 'pre_user_creation',
'timestamp': datetime.now(timezone.utc).isoformat()
}
customer = await self.payment_service.create_customer(customer_data)
customer = await self.payment_service.create_customer(email, name, metadata)
logger.info("Payment customer created for registration",
customer_id=customer.id,
email=user_data.get('email'))
@@ -1665,9 +1693,12 @@ class SubscriptionOrchestrationService:
plan_id=plan_id,
payment_method_id=payment_method_id)
subscription_result = await self.payment_service.create_payment_subscription(
# Get the Stripe price ID for this plan
price_id = self.payment_service._get_stripe_price_id(plan_id, billing_interval)
subscription_result = await self.payment_service.create_subscription_with_verified_payment(
customer.id,
plan_id,
price_id,
payment_method_id,
trial_period_days if trial_period_days > 0 else None,
billing_interval
@@ -1678,14 +1709,19 @@ class SubscriptionOrchestrationService:
logger.info("Registration payment setup requires SetupIntent confirmation",
customer_id=customer.id,
action_type=subscription_result.get('action_type'),
setup_intent_id=subscription_result.get('setup_intent_id'))
setup_intent_id=subscription_result.get('setup_intent_id'),
subscription_id=subscription_result.get('subscription_id'))
# Return the SetupIntent data for frontend to handle 3DS
# Note: subscription_id is included because for trial subscriptions,
# the subscription is already created in 'trialing' status even though
# the SetupIntent requires 3DS verification for future payments
return {
"requires_action": True,
"action_type": subscription_result.get('action_type'),
"action_type": subscription_result.get('action_type') or 'use_stripe_sdk',
"client_secret": subscription_result.get('client_secret'),
"setup_intent_id": subscription_result.get('setup_intent_id'),
"subscription_id": subscription_result.get('subscription_id'),
"customer_id": customer.id,
"payment_customer_id": customer.id,
"plan_id": plan_id,
@@ -1764,3 +1800,154 @@ class SubscriptionOrchestrationService:
error=str(e),
exc_info=True)
raise
async def validate_plan_upgrade(
self,
tenant_id: str,
new_plan: str
) -> Dict[str, Any]:
"""
Validate if a tenant can upgrade to a new plan
Args:
tenant_id: Tenant ID
new_plan: New plan to validate upgrade to
Returns:
Dictionary with validation result
"""
try:
logger.info("Validating plan upgrade",
tenant_id=tenant_id,
new_plan=new_plan)
# Delegate to subscription service for validation
can_upgrade = await self.subscription_service.validate_subscription_change(
tenant_id,
new_plan
)
result = {
"can_upgrade": can_upgrade,
"tenant_id": tenant_id,
"current_plan": None, # Would need to fetch current plan if needed
"new_plan": new_plan
}
if not can_upgrade:
result["reason"] = "Subscription change not allowed based on current status"
logger.info("Plan upgrade validation completed",
tenant_id=tenant_id,
can_upgrade=can_upgrade)
return result
except Exception as e:
logger.error("Plan upgrade validation failed",
tenant_id=tenant_id,
new_plan=new_plan,
error=str(e),
exc_info=True)
raise DatabaseError(f"Failed to validate plan upgrade: {str(e)}")
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get all subscriptions for a given customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
return await self.subscription_service.get_subscriptions_by_customer_id(customer_id)
except Exception as e:
logger.error("Failed to get subscriptions by customer ID",
customer_id=customer_id,
error=str(e),
exc_info=True)
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
async def update_subscription_with_verified_payment(
self,
subscription_id: str,
customer_id: str,
payment_method_id: str,
trial_period_days: Optional[int] = None
) -> Dict[str, Any]:
"""
Update an existing subscription with a verified payment method
This is used when we already have a trial subscription and just need to
attach the verified payment method to it.
Args:
subscription_id: Stripe subscription ID
customer_id: Stripe customer ID
payment_method_id: Verified payment method ID
trial_period_days: Optional trial period (for validation)
Returns:
Dictionary with updated subscription details
"""
try:
logger.info("Updating existing subscription with verified payment method",
subscription_id=subscription_id,
customer_id=customer_id,
payment_method_id=payment_method_id)
# First, verify the subscription exists and get its current status
existing_subscription = await self.subscription_service.get_subscription_by_provider_id(subscription_id)
if not existing_subscription:
raise SubscriptionNotFound(f"Subscription {subscription_id} not found")
# Update the subscription in Stripe with the verified payment method
stripe_subscription = await self.payment_service.update_subscription_payment_method(
subscription_id,
payment_method_id
)
# Update our local subscription record
await self.subscription_service.update_subscription_status(
existing_subscription.tenant_id,
stripe_subscription.status,
{
'current_period_start': datetime.fromtimestamp(stripe_subscription.current_period_start),
'current_period_end': datetime.fromtimestamp(stripe_subscription.current_period_end)
}
)
# Create a mock subscription object-like dict for compatibility
class SubscriptionResult:
def __init__(self, data: Dict[str, Any]):
self.id = data.get('subscription_id')
self.status = data.get('status')
self.current_period_start = data.get('current_period_start')
self.current_period_end = data.get('current_period_end')
self.customer = data.get('customer_id')
return {
'subscription': SubscriptionResult({
'subscription_id': stripe_subscription.id,
'status': stripe_subscription.status,
'current_period_start': stripe_subscription.current_period_start,
'current_period_end': stripe_subscription.current_period_end,
'customer_id': customer_id
}),
'verification': {
'verified': True,
'customer_id': customer_id,
'payment_method_id': payment_method_id
}
}
except Exception as e:
logger.error("Failed to update subscription with verified payment",
subscription_id=subscription_id,
customer_id=customer_id,
error=str(e),
exc_info=True)
raise SubscriptionUpdateFailed(f"Failed to update subscription: {str(e)}")

View File

@@ -68,28 +68,46 @@ class SubscriptionService:
'tenant_id': str(tenant_id),
'subscription_id': subscription_id,
'customer_id': customer_id,
'plan_id': plan,
'plan': plan,
'status': status,
'created_at': datetime.now(timezone.utc),
'trial_period_days': trial_period_days,
'billing_cycle': billing_interval
}
created_subscription = await self.subscription_repo.create(subscription_data)
# Add trial-related data if applicable
if trial_period_days and trial_period_days > 0:
from datetime import timedelta
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
subscription_data['trial_ends_at'] = trial_ends_at
logger.info("subscription_record_created",
tenant_id=tenant_id,
subscription_id=subscription_id,
plan=plan)
# Check if subscription with this subscription_id already exists to prevent duplicates
existing_subscription = await self.subscription_repo.get_by_provider_id(subscription_id)
if existing_subscription:
# Update the existing subscription instead of creating a duplicate
updated_subscription = await self.subscription_repo.update(str(existing_subscription.id), subscription_data)
return created_subscription
logger.info("Existing subscription updated",
tenant_id=tenant_id,
subscription_id=subscription_id,
plan=plan)
return updated_subscription
else:
# Create new subscription
created_subscription = await self.subscription_repo.create(subscription_data)
logger.info("subscription_record_created",
tenant_id=tenant_id,
subscription_id=subscription_id,
plan=plan)
return created_subscription
except ValidationError as ve:
logger.error("create_subscription_record_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"create_subscription_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("create_subscription_record_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"create_subscription_record_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to create subscription record: {str(e)}")
async def update_subscription_status(
@@ -126,15 +144,15 @@ class SubscriptionService:
# Include Stripe data if provided
if stripe_data:
if 'current_period_start' in stripe_data:
update_data['current_period_start'] = stripe_data['current_period_start']
if 'current_period_end' in stripe_data:
update_data['current_period_end'] = stripe_data['current_period_end']
# Note: current_period_start and current_period_end are not in the local model
# These would need to be stored separately or handled differently
# For now, we'll skip storing these Stripe-specific fields in the local model
pass
# Update status flags based on status value
if status == 'active':
update_data['is_active'] = True
update_data['canceled_at'] = None
update_data['cancelled_at'] = None
elif status in ['canceled', 'past_due', 'unpaid', 'inactive']:
update_data['is_active'] = False
elif status == 'pending_cancellation':
@@ -153,11 +171,10 @@ class SubscriptionService:
return updated_subscription
except ValidationError as ve:
logger.error("update_subscription_status_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"update_subscription_status_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("update_subscription_status_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"update_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
async def get_subscription_by_tenant_id(
@@ -201,6 +218,23 @@ class SubscriptionService:
error=str(e), subscription_id=subscription_id)
return None
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get subscriptions by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
return await self.subscription_repo.get_by_customer_id(customer_id)
except Exception as e:
logger.error("get_subscriptions_by_customer_id_failed",
error=str(e), customer_id=customer_id)
return []
async def cancel_subscription(
self,
tenant_id: str,
@@ -264,11 +298,10 @@ class SubscriptionService:
}
except ValidationError as ve:
logger.error("subscription_cancellation_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"subscription_cancellation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("subscription_cancellation_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"subscription_cancellation_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
async def reactivate_subscription(
@@ -301,7 +334,7 @@ class SubscriptionService:
# Update subscription status and plan
update_data = {
'status': 'active',
'plan_id': plan,
'plan': plan,
'cancelled_at': None,
'cancellation_effective_date': None
}
@@ -329,11 +362,10 @@ class SubscriptionService:
}
except ValidationError as ve:
logger.error("subscription_reactivation_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"subscription_reactivation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("subscription_reactivation_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"subscription_reactivation_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
async def get_subscription_status(
@@ -378,7 +410,7 @@ class SubscriptionService:
error=str(ve), tenant_id=tenant_id)
raise ve
except Exception as e:
logger.error("get_subscription_status_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"get_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to get subscription status: {str(e)}")
async def update_subscription_plan_record(
@@ -416,13 +448,14 @@ class SubscriptionService:
# Update local subscription record
update_data = {
'plan_id': new_plan,
'plan': new_plan,
'status': new_status,
'current_period_start': new_period_start,
'current_period_end': new_period_end,
'updated_at': datetime.now(timezone.utc)
}
# Note: current_period_start and current_period_end are not in the local model
# These Stripe-specific fields would need to be stored separately
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
# Invalidate subscription cache
@@ -447,11 +480,10 @@ class SubscriptionService:
}
except ValidationError as ve:
logger.error("update_subscription_plan_record_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"update_subscription_plan_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("update_subscription_plan_record_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"update_subscription_plan_record_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to update subscription plan record: {str(e)}")
async def update_billing_cycle_record(
@@ -490,11 +522,12 @@ class SubscriptionService:
# Update local subscription record
update_data = {
'status': new_status,
'current_period_start': new_period_start,
'current_period_end': new_period_end,
'updated_at': datetime.now(timezone.utc)
}
# Note: current_period_start and current_period_end are not in the local model
# These Stripe-specific fields would need to be stored separately
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
# Invalidate subscription cache
@@ -521,11 +554,10 @@ class SubscriptionService:
}
except ValidationError as ve:
logger.error("change_billing_cycle_validation_failed",
error=str(ve), tenant_id=tenant_id)
logger.error(f"change_billing_cycle_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
raise ve
except Exception as e:
logger.error("change_billing_cycle_failed", error=str(e), tenant_id=tenant_id)
logger.error(f"change_billing_cycle_failed, tenant_id={tenant_id}, error={str(e)}")
raise DatabaseError(f"Failed to change billing cycle: {str(e)}")
async def _invalidate_cache(self, tenant_id: str):
@@ -620,13 +652,18 @@ class SubscriptionService:
'plan': plan, # Repository expects 'plan', not 'plan_id'
'status': status,
'created_at': datetime.now(timezone.utc),
'trial_period_days': trial_period_days,
'billing_cycle': billing_interval,
'user_id': user_id,
'is_tenant_linked': False,
'tenant_linking_status': 'pending'
}
# Add trial-related data if applicable
if trial_period_days and trial_period_days > 0:
from datetime import timedelta
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
subscription_data['trial_ends_at'] = trial_ends_at
created_subscription = await self.subscription_repo.create_tenant_independent_subscription(subscription_data)
logger.info("tenant_independent_subscription_record_created",
@@ -650,7 +687,7 @@ class SubscriptionService:
try:
return await self.subscription_repo.get_pending_tenant_linking_subscriptions()
except Exception as e:
logger.error("Failed to get pending tenant linking subscriptions", error=str(e))
logger.error(f"Failed to get pending tenant linking subscriptions: {str(e)}")
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
@@ -699,5 +736,57 @@ class SubscriptionService:
try:
return await self.subscription_repo.cleanup_orphaned_subscriptions(days_old)
except Exception as e:
logger.error("Failed to cleanup orphaned subscriptions", error=str(e))
logger.error(f"Failed to cleanup orphaned subscriptions: {str(e)}")
raise DatabaseError(f"Failed to cleanup orphaned subscriptions: {str(e)}")
async def update_subscription_info(
self,
subscription_id: str,
update_data: Dict[str, Any]
) -> Subscription:
"""
Update subscription-related information (3DS flags, status, etc.)
This is useful for updating tenant-independent subscriptions during registration.
Args:
subscription_id: Subscription ID
update_data: Dictionary with fields to update
Returns:
Updated Subscription object
"""
try:
# Filter allowed fields
allowed_fields = {
'plan', 'status', 'is_tenant_linked', 'tenant_linking_status',
'threeds_authentication_required', 'threeds_authentication_required_at',
'threeds_authentication_completed', 'threeds_authentication_completed_at',
'last_threeds_setup_intent_id', 'threeds_action_type'
}
filtered_data = {k: v for k, v in update_data.items() if k in allowed_fields}
if not filtered_data:
logger.warning("No valid subscription info fields provided for update",
subscription_id=subscription_id)
return await self.subscription_repo.get_by_id(subscription_id)
updated_subscription = await self.subscription_repo.update(subscription_id, filtered_data)
if not updated_subscription:
raise ValidationError(f"Subscription not found: {subscription_id}")
logger.info("Subscription info updated",
subscription_id=subscription_id,
updated_fields=list(filtered_data.keys()))
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription info",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to update subscription info: {str(e)}")

View File

@@ -12,7 +12,7 @@ from fastapi import HTTPException, status
from app.repositories import TenantRepository, TenantMemberRepository, SubscriptionRepository
from app.models.tenants import Tenant, TenantMember, Subscription
from app.schemas.tenants import (
BakeryRegistration, TenantResponse, TenantAccessResponse,
BakeryRegistration, TenantResponse, TenantAccessResponse,
TenantUpdate, TenantMemberResponse
)
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
@@ -25,11 +25,11 @@ logger = structlog.get_logger()
class EnhancedTenantService:
"""Enhanced tenant management business logic using repository pattern with dependency injection"""
def __init__(self, database_manager=None, event_publisher=None):
self.database_manager = database_manager or create_database_manager()
self.event_publisher = event_publisher
async def _init_repositories(self, session):
"""Initialize repositories with session"""
self.tenant_repo = TenantRepository(Tenant, session)
@@ -40,15 +40,15 @@ class EnhancedTenantService:
'member': self.member_repo,
'subscription': self.subscription_repo
}
async def create_bakery(
self,
bakery_data: BakeryRegistration,
self,
bakery_data: BakeryRegistration,
owner_id: str,
session=None
) -> TenantResponse:
"""Create a new bakery/tenant with enhanced validation and features using repository pattern"""
try:
async with self.database_manager.get_session() as db_session:
async with UnitOfWork(db_session) as uow:
@@ -115,10 +115,10 @@ class EnhancedTenantService:
"longitude": longitude,
"is_active": True
}
# Create tenant using repository
tenant = await tenant_repo.create_tenant(tenant_data)
# Create owner membership
membership_data = {
"tenant_id": str(tenant.id),
@@ -126,7 +126,7 @@ class EnhancedTenantService:
"role": "owner",
"is_active": True
}
owner_membership = await member_repo.create_membership(membership_data)
# Get subscription plan from user's registration using standardized auth client
@@ -164,10 +164,10 @@ class EnhancedTenantService:
logger.info("Subscription created",
tenant_id=tenant.id,
plan=selected_plan)
# Commit the transaction
await uow.commit()
# Publish tenant created event
if self.event_publisher:
try:
@@ -245,7 +245,7 @@ class EnhancedTenantService:
subdomain=tenant.subdomain)
return TenantResponse.from_orm(tenant)
except (ValidationError, DuplicateRecordError) as e:
logger.error("Validation error creating bakery",
name=bakery_data.name,
@@ -264,19 +264,19 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create bakery"
)
async def verify_user_access(
self,
user_id: str,
self,
user_id: str,
tenant_id: str
) -> TenantAccessResponse:
"""Verify if user has access to tenant with enhanced permissions"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
access_info = await self.member_repo.verify_user_access(user_id, tenant_id)
return TenantAccessResponse(
has_access=access_info["has_access"],
role=access_info["role"],
@@ -284,7 +284,7 @@ class EnhancedTenantService:
membership_id=access_info.get("membership_id"),
joined_at=access_info.get("joined_at")
)
except Exception as e:
logger.error("Error verifying user access",
user_id=user_id,
@@ -295,10 +295,10 @@ class EnhancedTenantService:
role="none",
permissions=[]
)
async def get_tenant_by_id(self, tenant_id: str) -> Optional[TenantResponse]:
"""Get tenant by ID with enhanced data"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
@@ -306,16 +306,16 @@ class EnhancedTenantService:
if tenant:
return TenantResponse.from_orm(tenant)
return None
except Exception as e:
logger.error("Error getting tenant",
tenant_id=tenant_id,
error=str(e))
return None
async def get_tenant_by_subdomain(self, subdomain: str) -> Optional[TenantResponse]:
"""Get tenant by subdomain"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
@@ -323,40 +323,40 @@ class EnhancedTenantService:
if tenant:
return TenantResponse.from_orm(tenant)
return None
except Exception as e:
logger.error("Error getting tenant by subdomain",
subdomain=subdomain,
error=str(e))
return None
async def get_user_tenants(self, user_id: str) -> List[TenantResponse]:
"""Get all tenants accessible by a user (both owned and member tenants)"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Get tenants where user is the owner
owned_tenants = await self.tenant_repo.get_tenants_by_owner(user_id)
# Get tenants where user is a member (but not owner)
memberships = await self.member_repo.get_user_memberships(user_id, active_only=True)
# Get tenant details for each membership
member_tenant_ids = [str(membership.tenant_id) for membership in memberships]
member_tenants = []
if member_tenant_ids:
# Get tenant details for each membership
for tenant_id in member_tenant_ids:
tenant = await self.tenant_repo.get_by_id(tenant_id)
if tenant:
member_tenants.append(tenant)
# Combine and deduplicate (in case user is both owner and member)
all_tenants = owned_tenants + member_tenants
# Remove duplicates by tenant ID
unique_tenants = []
seen_ids = set()
@@ -364,7 +364,7 @@ class EnhancedTenantService:
if str(tenant.id) not in seen_ids:
seen_ids.add(str(tenant.id))
unique_tenants.append(tenant)
logger.info(
"Retrieved user tenants",
user_id=user_id,
@@ -372,7 +372,7 @@ class EnhancedTenantService:
member_count=len(member_tenants),
total_count=len(unique_tenants)
)
return [TenantResponse.from_orm(tenant) for tenant in unique_tenants]
except Exception as e:
@@ -510,7 +510,7 @@ class EnhancedTenantService:
limit=limit,
error=str(e))
return []
async def search_tenants(
self,
search_term: str,
@@ -520,7 +520,7 @@ class EnhancedTenantService:
limit: int = 50
) -> List[TenantResponse]:
"""Search tenants with filters"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
@@ -528,13 +528,13 @@ class EnhancedTenantService:
search_term, business_type, city, skip, limit
)
return [TenantResponse.from_orm(tenant) for tenant in tenants]
except Exception as e:
logger.error("Error searching tenants",
search_term=search_term,
error=str(e))
return []
async def update_tenant(
self,
tenant_id: str,
@@ -590,17 +590,17 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update tenant"
)
async def add_team_member(
self,
tenant_id: str,
user_id: str,
role: str,
self,
tenant_id: str,
user_id: str,
role: str,
invited_by: str,
session: AsyncSession = None
) -> TenantMemberResponse:
"""Add a team member to tenant with enhanced validation"""
try:
# Verify inviter has admin access
access = await self.verify_user_access(invited_by, tenant_id)
@@ -609,11 +609,11 @@ class EnhancedTenantService:
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to add team members"
)
# Create membership using repository
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
membership_data = {
"tenant_id": tenant_id,
"user_id": user_id,
@@ -621,9 +621,9 @@ class EnhancedTenantService:
"invited_by": invited_by,
"is_active": True
}
member = await self.member_repo.create_membership(membership_data)
# Publish member added event
if self.event_publisher:
try:
@@ -640,15 +640,15 @@ class EnhancedTenantService:
)
except Exception as e:
logger.warning("Failed to publish member added event", error=str(e))
logger.info("Team member added successfully",
tenant_id=tenant_id,
user_id=user_id,
role=role,
invited_by=invited_by)
return TenantMemberResponse.from_orm(member)
except HTTPException:
raise
except (ValidationError, DuplicateRecordError) as e:
@@ -669,7 +669,7 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add team member"
)
async def get_team_members(
self,
tenant_id: str,
@@ -697,7 +697,7 @@ class EnhancedTenantService:
user_id=user_id,
error=str(e))
return []
async def update_member_role(
self,
tenant_id: str,
@@ -707,7 +707,7 @@ class EnhancedTenantService:
session: AsyncSession = None
) -> TenantMemberResponse:
"""Update team member role"""
try:
# Verify updater has admin access
access = await self.verify_user_access(updated_by, tenant_id)
@@ -716,19 +716,19 @@ class EnhancedTenantService:
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to update member roles"
)
updated_member = await self.member_repo.update_member_role(
tenant_id, member_user_id, new_role, updated_by
)
if not updated_member:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Member not found"
)
return TenantMemberResponse.from_orm(updated_member)
except HTTPException:
raise
except (ValidationError, DuplicateRecordError) as e:
@@ -745,7 +745,7 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update member role"
)
async def remove_team_member(
self,
tenant_id: str,
@@ -754,7 +754,7 @@ class EnhancedTenantService:
session: AsyncSession = None
) -> bool:
"""Remove team member from tenant"""
try:
# Verify remover has admin access
access = await self.verify_user_access(removed_by, tenant_id)
@@ -763,19 +763,19 @@ class EnhancedTenantService:
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions to remove team members"
)
removed_member = await self.member_repo.deactivate_membership(
tenant_id, member_user_id, removed_by
)
if not removed_member:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Member not found"
)
return True
except HTTPException:
raise
except ValidationError as e:
@@ -798,10 +798,10 @@ class EnhancedTenantService:
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Get all user memberships
memberships = await self.member_repo.get_user_memberships(user_id, active_only=False)
# Convert to response format
result = []
for membership in memberships:
@@ -814,13 +814,13 @@ class EnhancedTenantService:
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None,
"invited_by": str(membership.invited_by) if membership.invited_by else None
})
logger.info("Retrieved user memberships",
user_id=user_id,
membership_count=len(result))
return result
except Exception as e:
logger.error("Failed to get user memberships",
user_id=user_id,
@@ -829,7 +829,7 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get user memberships"
)
async def update_model_status(
self,
tenant_id: str,
@@ -838,7 +838,7 @@ class EnhancedTenantService:
last_training_date: datetime = None
) -> TenantResponse:
"""Update tenant model training status"""
try:
# Verify user has access
access = await self.verify_user_access(user_id, tenant_id)
@@ -847,21 +847,21 @@ class EnhancedTenantService:
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant"
)
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
updated_tenant = await self.tenant_repo.update_tenant_model_status(
tenant_id, ml_model_trained, last_training_date
)
if not updated_tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found"
)
return TenantResponse.from_orm(updated_tenant)
except HTTPException:
raise
except Exception as e:
@@ -872,31 +872,31 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update model status"
)
async def get_tenant_statistics(self) -> Dict[str, Any]:
"""Get comprehensive tenant statistics"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Get tenant statistics
tenant_stats = await self.tenant_repo.get_tenant_statistics()
# Get subscription statistics
subscription_stats = await self.subscription_repo.get_subscription_statistics()
return {
"tenants": tenant_stats,
"subscriptions": subscription_stats
}
except Exception as e:
logger.error("Error getting tenant statistics", error=str(e))
logger.error(f"Error getting tenant statistics: {str(e)}")
return {
"tenants": {},
"subscriptions": {}
}
async def get_tenants_near_location(
self,
latitude: float,
@@ -905,23 +905,23 @@ class EnhancedTenantService:
limit: int = 50
) -> List[TenantResponse]:
"""Get tenants near a geographic location"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
tenants = await self.tenant_repo.get_tenants_by_location(
latitude, longitude, radius_km, limit
)
return [TenantResponse.from_orm(tenant) for tenant in tenants]
except Exception as e:
logger.error("Error getting tenants by location",
latitude=latitude,
longitude=longitude,
error=str(e))
return []
async def deactivate_tenant(
self,
tenant_id: str,
@@ -929,7 +929,7 @@ class EnhancedTenantService:
session: AsyncSession = None
) -> bool:
"""Deactivate a tenant (admin only)"""
try:
# Verify user is owner
access = await self.verify_user_access(user_id, tenant_id)
@@ -938,29 +938,29 @@ class EnhancedTenantService:
status_code=status.HTTP_403_FORBIDDEN,
detail="Only tenant owner can deactivate tenant"
)
deactivated_tenant = await self.tenant_repo.deactivate_tenant(tenant_id)
if not deactivated_tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found"
)
# Also suspend subscription
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
if subscription:
await self.subscription_repo.suspend_subscription(
str(subscription.id),
str(subscription.id),
"Tenant deactivated"
)
logger.info("Tenant deactivated",
tenant_id=tenant_id,
deactivated_by=user_id)
return True
except HTTPException:
raise
except Exception as e:
@@ -971,7 +971,7 @@ class EnhancedTenantService:
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to deactivate tenant"
)
async def activate_tenant(
self,
tenant_id: str,
@@ -1378,7 +1378,7 @@ class EnhancedTenantService:
# ========================================================================
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
# ========================================================================
async def link_subscription_to_tenant(
self,
tenant_id: str,
@@ -1387,15 +1387,15 @@ class EnhancedTenantService:
) -> Dict[str, Any]:
"""
Link a pending subscription to a tenant
This completes the registration flow by associating the subscription
created during registration with the tenant created during onboarding
Args:
tenant_id: Tenant ID to link to
subscription_id: Subscription ID to link
user_id: User ID performing the linking (for validation)
Returns:
Dictionary with linking results
"""
@@ -1409,30 +1409,30 @@ class EnhancedTenantService:
tenant_repo = uow.register_repository(
"tenants", TenantRepository, Tenant
)
# Get the subscription
subscription = await subscription_repo.get_by_id(subscription_id)
if not subscription:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Subscription not found"
)
# Verify subscription is in pending_tenant_linking state
if subscription.tenant_linking_status != "pending":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Subscription is not in pending tenant linking state"
)
# Verify subscription belongs to this user
if subscription.user_id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Subscription does not belong to this user"
)
# Update subscription with tenant_id
update_data = {
"tenant_id": tenant_id,
@@ -1440,36 +1440,41 @@ class EnhancedTenantService:
"tenant_linking_status": "completed",
"linked_at": datetime.now(timezone.utc)
}
await subscription_repo.update(subscription_id, update_data)
# Update tenant with subscription information
# Update tenant with subscription information including 3DS flags
tenant_update = {
"customer_id": subscription.customer_id,
"subscription_status": subscription.status,
"subscription_plan": subscription.plan,
"subscription_tier": subscription.plan,
"billing_cycle": subscription.billing_cycle,
"trial_period_days": subscription.trial_period_days
"trial_period_days": subscription.trial_period_days,
"threeds_authentication_required": getattr(subscription, 'threeds_authentication_required', False),
"threeds_authentication_required_at": getattr(subscription, 'threeds_authentication_required_at', None),
"threeds_authentication_completed": getattr(subscription, 'threeds_authentication_completed', False),
"threeds_authentication_completed_at": getattr(subscription, 'threeds_authentication_completed_at', None),
"last_threeds_setup_intent_id": getattr(subscription, 'last_threeds_setup_intent_id', None),
"threeds_action_type": getattr(subscription, 'threeds_action_type', None)
}
await tenant_repo.update_tenant(tenant_id, tenant_update)
await tenant_repo.update(tenant_id, tenant_update)
# Commit transaction
await uow.commit()
logger.info("Subscription successfully linked to tenant",
tenant_id=tenant_id,
subscription_id=subscription_id,
user_id=user_id)
return {
"success": True,
"tenant_id": tenant_id,
"subscription_id": subscription_id,
"status": "linked"
}
except Exception as e:
logger.error("Failed to link subscription to tenant",
error=str(e),
@@ -1478,5 +1483,127 @@ class EnhancedTenantService:
user_id=user_id)
raise
async def update_tenant_subscription_info(
self,
tenant_id: str,
update_data: Dict[str, Any]
) -> TenantResponse:
"""
Update tenant subscription-related information (plan, status, 3DS flags)
Args:
tenant_id: Tenant ID
update_data: Dictionary with fields to update
Returns:
Updated Tenant object
"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Filter allowed fields to prevent accidental overwrites of core tenant data
allowed_fields = {
'subscription_plan', 'subscription_status', 'subscription_tier',
'billing_cycle', 'trial_period_days', 'customer_id',
'threeds_authentication_required', 'threeds_authentication_required_at',
'threeds_authentication_completed', 'threeds_authentication_completed_at',
'last_threeds_setup_intent_id', 'threeds_action_type'
}
filtered_data = {k: v for k, v in update_data.items() if k in allowed_fields}
if not filtered_data:
logger.warning("No valid subscription info fields provided for update",
tenant_id=tenant_id)
tenant = await self.tenant_repo.get_by_id(tenant_id)
return TenantResponse.from_orm(tenant)
updated_tenant = await self.tenant_repo.update(tenant_id, filtered_data)
if not updated_tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Tenant not found"
)
logger.info("Tenant subscription info updated",
tenant_id=tenant_id,
updated_fields=list(filtered_data.keys()))
return TenantResponse.from_orm(updated_tenant)
except HTTPException:
raise
except Exception as e:
logger.error("Failed to update tenant subscription info",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update tenant subscription info: {str(e)}"
)
async def get_tenant_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
"""
Get tenant by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
Tenant object if found, None otherwise
"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Search for tenant with matching customer_id
tenant = await self.tenant_repo.get_by_customer_id(customer_id)
if tenant:
logger.info("Found tenant by customer_id",
customer_id=customer_id,
tenant_id=str(tenant.id))
return tenant
else:
logger.info("No tenant found for customer_id",
customer_id=customer_id)
return None
except Exception as e:
logger.error("Error getting tenant by customer_id",
customer_id=customer_id,
error=str(e))
return None
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get subscriptions by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Search for subscriptions with matching customer_id
subscriptions = await self.subscription_repo.get_by_customer_id(customer_id)
logger.info("Found subscriptions by customer_id",
customer_id=customer_id,
count=len(subscriptions))
return subscriptions
except Exception as e:
logger.error("Error getting subscriptions by customer_id",
customer_id=customer_id,
error=str(e))
return []
# Legacy compatibility alias
TenantService = EnhancedTenantService

View File

@@ -97,7 +97,7 @@ class TenantSettingsService:
return settings
except Exception as e:
logger.error("Failed to get or create tenant settings", tenant_id=tenant_id, error=str(e), exc_info=True)
logger.error(f"Failed to get or create tenant settings, tenant_id={tenant_id}, error={str(e)}", exc_info=True)
# Re-raise as HTTPException to match the expected behavior
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,