Files
bakery-ia/services/tenant/app/services/subscription_service.py
2026-01-15 20:45:49 +01:00

793 lines
30 KiB
Python

"""
Subscription Service - State Manager
This service handles ONLY subscription database operations and state management
NO payment provider interactions, NO orchestration, NO coupon logic
"""
import structlog
from typing import Dict, Any, Optional, List
from datetime import datetime, timezone, timedelta
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.tenants import Subscription, Tenant
from app.repositories.subscription_repository import SubscriptionRepository
from app.core.config import settings
from shared.database.exceptions import DatabaseError, ValidationError
from shared.subscription.plans import PlanPricing, QuotaLimits, SubscriptionPlanMetadata
logger = structlog.get_logger()
class SubscriptionService:
"""Service for managing subscription state and database operations ONLY"""
def __init__(self, db_session: AsyncSession):
    """
    Bind the service to a database session.

    Args:
        db_session: Async SQLAlchemy session. It is kept on the instance and
            also passed to the SubscriptionRepository created here.
    """
    self.db_session = db_session
    repository = SubscriptionRepository(Subscription, db_session)
    self.subscription_repo = repository
async def create_subscription_record(
    self,
    tenant_id: str,
    subscription_id: str,
    customer_id: str,
    plan: str,
    status: str,
    trial_period_days: Optional[int] = None,
    billing_interval: str = "monthly"
) -> Subscription:
    """
    Create a local subscription record in the database, or refresh the
    existing record when the provider subscription ID is already stored.

    Args:
        tenant_id: Tenant ID (must parse as a UUID and match an existing tenant)
        subscription_id: Payment provider subscription ID
        customer_id: Payment provider customer ID
        plan: Subscription plan
        status: Subscription status
        trial_period_days: Optional trial period in days (ignored unless positive)
        billing_interval: Billing interval (monthly or yearly), stored as 'billing_cycle'

    Returns:
        The created Subscription, or the updated one when a record with the
        same provider subscription ID already existed

    Raises:
        ValidationError: If no tenant exists for tenant_id
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Verify tenant exists
        tenant_query = select(Tenant).where(Tenant.id == tenant_uuid)
        tenant_result = await self.db_session.execute(tenant_query)
        if tenant_result.scalar_one_or_none() is None:
            raise ValidationError(f"Tenant not found: {tenant_id}")

        now = datetime.now(timezone.utc)
        subscription_data = {
            # Store the normalised UUID string: the lookups in this service
            # query by str(UUID(tenant_id)), so the stored form must match.
            'tenant_id': str(tenant_uuid),
            'subscription_id': subscription_id,
            'customer_id': customer_id,
            'plan': plan,
            'status': status,
            'billing_cycle': billing_interval
        }

        # Add trial-related data if applicable
        if trial_period_days and trial_period_days > 0:
            subscription_data['trial_ends_at'] = now + timedelta(days=trial_period_days)

        # Check if a subscription with this provider ID already exists to prevent duplicates
        existing_subscription = await self.subscription_repo.get_by_provider_id(subscription_id)
        if existing_subscription:
            # Refreshing an existing row must not reset its creation timestamp,
            # so only 'updated_at' is stamped on this path.
            subscription_data['updated_at'] = now
            updated_subscription = await self.subscription_repo.update(
                str(existing_subscription.id), subscription_data
            )
            # Plan/status may have changed, so drop any cached copy like the
            # other mutating methods of this service do.
            await self._invalidate_cache(tenant_id)
            logger.info("Existing subscription updated",
                        tenant_id=tenant_id,
                        subscription_id=subscription_id,
                        plan=plan)
            return updated_subscription

        # Create new subscription
        subscription_data['created_at'] = now
        created_subscription = await self.subscription_repo.create(subscription_data)
        logger.info("subscription_record_created",
                    tenant_id=tenant_id,
                    subscription_id=subscription_id,
                    plan=plan)
        return created_subscription
    except ValidationError as ve:
        logger.error("create_subscription_record_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("create_subscription_record_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to create subscription record: {str(e)}") from e
async def update_subscription_status(
    self,
    tenant_id: str,
    status: str,
    stripe_data: Optional[Dict[str, Any]] = None
) -> Subscription:
    """
    Update subscription status in database

    Args:
        tenant_id: Tenant ID
        status: New subscription status
        stripe_data: Optional data from Stripe. Accepted for interface
            compatibility but currently not persisted: the local model has
            no current_period_start / current_period_end fields.

    Returns:
        Updated Subscription object

    Raises:
        ValidationError: If the tenant has no subscription
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        # Capture the previous status BEFORE writing: if the repository
        # updates this same object in place, reading it afterwards would
        # report the new status as the old one.
        old_status = subscription.status

        # Prepare update data
        update_data = {
            'status': status,
            'updated_at': datetime.now(timezone.utc)
        }

        # Update status flags based on status value
        if status == 'active':
            update_data['is_active'] = True
            update_data['cancelled_at'] = None
        elif status in ('canceled', 'past_due', 'unpaid', 'inactive'):
            update_data['is_active'] = False
        elif status == 'pending_cancellation':
            update_data['is_active'] = True  # Still active until effective date
        # Any other status leaves is_active untouched.

        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)

        logger.info("subscription_status_updated",
                    tenant_id=tenant_id,
                    old_status=old_status,
                    new_status=status)
        return updated_subscription
    except ValidationError as ve:
        logger.error("update_subscription_status_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("update_subscription_status_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to update subscription status: {str(e)}") from e
async def get_subscription_by_tenant_id(
    self,
    tenant_id: str
) -> Optional[Subscription]:
    """
    Look up the subscription that belongs to a tenant.

    This is a best-effort lookup: any failure (including a tenant_id that
    does not parse as a UUID) is logged and reported as None.

    Args:
        tenant_id: Tenant ID

    Returns:
        Subscription object, or None when nothing was found or the lookup failed
    """
    try:
        normalized_id = str(UUID(tenant_id))
        found = await self.subscription_repo.get_by_tenant_id(normalized_id)
    except Exception as lookup_error:
        logger.error("get_subscription_by_tenant_id_failed",
                     error=str(lookup_error), tenant_id=tenant_id)
        return None
    return found
async def get_subscription_by_provider_id(
    self,
    subscription_id: str
) -> Optional[Subscription]:
    """
    Look up a subscription by its payment provider subscription ID.

    Best-effort: a failing repository call is logged and reported as None.

    Args:
        subscription_id: Payment provider subscription ID

    Returns:
        Subscription object, or None when nothing was found or the lookup failed
    """
    try:
        found = await self.subscription_repo.get_by_provider_id(subscription_id)
    except Exception as lookup_error:
        logger.error("get_subscription_by_provider_id_failed",
                     error=str(lookup_error), subscription_id=subscription_id)
        return None
    return found
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
    """
    List the subscriptions stored for a Stripe customer ID.

    Best-effort: a failing repository call is logged and reported as an
    empty list.

    Args:
        customer_id: Stripe customer ID

    Returns:
        List of Subscription objects (empty when the lookup failed)
    """
    try:
        matches = await self.subscription_repo.get_by_customer_id(customer_id)
    except Exception as lookup_error:
        logger.error("get_subscriptions_by_customer_id_failed",
                     error=str(lookup_error), customer_id=customer_id)
        return []
    return matches
async def cancel_subscription(
    self,
    tenant_id: str,
    reason: str = ""
) -> Dict[str, Any]:
    """
    Mark subscription as pending cancellation in database

    The cancellation takes effect on the subscription's next billing date,
    or 30 days from now when no next billing date is stored.

    Args:
        tenant_id: Tenant ID to cancel subscription for
        reason: Optional cancellation reason (only logged, truncated to 200 chars)

    Returns:
        Dictionary with cancellation details; 'days_remaining' is never negative

    Raises:
        ValidationError: If the tenant has no subscription, or it is already
            pending cancellation / inactive
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        if subscription.status in ('pending_cancellation', 'inactive'):
            raise ValidationError(f"Subscription is already {subscription.status}")

        # Single reference instant so every timestamp of this operation agrees
        now = datetime.now(timezone.utc)

        # Calculate cancellation effective date (end of billing period)
        cancellation_effective_date = subscription.next_billing_date or (
            now + timedelta(days=30)
        )
        # A naive stored datetime cannot be subtracted from an aware "now";
        # interpret it as UTC, like the other timestamps written here.
        if cancellation_effective_date.utcoffset() is None:
            cancellation_effective_date = cancellation_effective_date.replace(tzinfo=timezone.utc)

        # Computed before the write so date arithmetic can never fail after
        # the record has already been changed. Clamped because a stale
        # billing date in the past would otherwise yield a negative count.
        days_remaining = max((cancellation_effective_date - now).days, 0)

        # Update subscription status in database
        update_data = {
            'status': 'pending_cancellation',
            'cancelled_at': now,
            'cancellation_effective_date': cancellation_effective_date
        }
        await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)

        logger.info(
            "subscription_cancelled",
            tenant_id=str(tenant_id),
            effective_date=cancellation_effective_date.isoformat(),
            reason=reason[:200] if reason else None
        )
        return {
            "success": True,
            "message": "Subscription cancelled successfully. You will have read-only access until the end of your billing period.",
            "status": "pending_cancellation",
            "cancellation_effective_date": cancellation_effective_date.isoformat(),
            "days_remaining": days_remaining,
            "read_only_mode_starts": cancellation_effective_date.isoformat()
        }
    except ValidationError as ve:
        logger.error("subscription_cancellation_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("subscription_cancellation_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to cancel subscription: {str(e)}") from e
async def reactivate_subscription(
    self,
    tenant_id: str,
    plan: str = "starter"
) -> Dict[str, Any]:
    """
    Reactivate a cancelled or inactive subscription

    Args:
        tenant_id: Tenant ID to reactivate subscription for
        plan: Plan to reactivate with. NOTE: the stored plan is always
            overwritten with this value, so callers that want to keep the
            current plan must pass it explicitly.

    Returns:
        Dictionary with reactivation details

    Raises:
        ValidationError: If the tenant has no subscription, or its status is
            neither 'pending_cancellation' nor 'inactive'
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        if subscription.status not in ('pending_cancellation', 'inactive'):
            raise ValidationError(f"Cannot reactivate subscription with status: {subscription.status}")

        # Update subscription status and plan
        update_data = {
            'status': 'active',
            # update_subscription_status clears is_active for inactive
            # subscriptions; restore it here so the flag and the status
            # agree after reactivation.
            'is_active': True,
            'plan': plan,
            'cancelled_at': None,
            'cancellation_effective_date': None
        }
        if subscription.status == 'inactive':
            # An inactive subscription gets a fresh 30-day period from now
            update_data['next_billing_date'] = datetime.now(timezone.utc) + timedelta(days=30)

        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)

        logger.info(
            "subscription_reactivated",
            tenant_id=str(tenant_id),
            new_plan=plan
        )

        # getattr tolerates a repository update that returns None
        next_billing_date = getattr(updated_subscription, 'next_billing_date', None)
        return {
            "success": True,
            "message": "Subscription reactivated successfully",
            "status": "active",
            "plan": plan,
            "next_billing_date": next_billing_date.isoformat() if next_billing_date else None
        }
    except ValidationError as ve:
        logger.error("subscription_reactivation_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("subscription_reactivation_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") from e
async def get_subscription_status(
    self,
    tenant_id: str
) -> Dict[str, Any]:
    """
    Get current subscription status including read-only mode info

    Args:
        tenant_id: Tenant ID to get status for

    Returns:
        Dictionary with subscription status details. 'is_read_only' is True
        for 'pending_cancellation' and 'inactive'. 'days_until_inactive' is
        only set while cancellation is pending and is never negative.

    Raises:
        ValidationError: If the tenant has no subscription
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        effective_date = subscription.cancellation_effective_date
        is_read_only = subscription.status in ('pending_cancellation', 'inactive')

        days_until_inactive = None
        if subscription.status == 'pending_cancellation' and effective_date:
            # A naive stored datetime cannot be subtracted from an aware
            # "now"; interpret it as UTC for the arithmetic only.
            comparable = effective_date
            if comparable.utcoffset() is None:
                comparable = comparable.replace(tzinfo=timezone.utc)
            # Clamp: an effective date that has already passed means zero
            # days left, not a negative count.
            days_until_inactive = max((comparable - datetime.now(timezone.utc)).days, 0)

        return {
            "tenant_id": str(tenant_id),
            "status": subscription.status,
            "plan": subscription.plan,
            "is_read_only": is_read_only,
            "cancellation_effective_date": effective_date.isoformat() if effective_date else None,
            "days_until_inactive": days_until_inactive
        }
    except ValidationError as ve:
        logger.error("get_subscription_status_validation_failed",
                     error=str(ve), tenant_id=tenant_id)
        raise
    except Exception as e:
        logger.error("get_subscription_status_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to get subscription status: {str(e)}") from e
async def update_subscription_plan_record(
    self,
    tenant_id: str,
    new_plan: str,
    new_status: str,
    new_period_start: datetime,
    new_period_end: datetime,
    billing_cycle: str = "monthly",
    proration_details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Update local subscription plan record in database

    Only 'plan', 'status' and 'updated_at' are written. new_period_start,
    new_period_end and billing_cycle are accepted but not persisted by this
    method (the local model has no current_period_start/current_period_end
    fields); new_period_end is echoed back in the result.

    Args:
        tenant_id: Tenant ID
        new_plan: New plan name
        new_status: New subscription status
        new_period_start: New period start date
        new_period_end: New period end date
        billing_cycle: Billing cycle for the new plan
        proration_details: Proration details from payment provider

    Returns:
        Dictionary with update results

    Raises:
        ValidationError: If the tenant has no subscription
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get current subscription
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        # Capture the previous plan BEFORE writing: if the repository updates
        # this same object in place, reading it afterwards would report the
        # new plan as the old one.
        old_plan = subscription.plan

        # Update local subscription record
        update_data = {
            'plan': new_plan,
            'status': new_status,
            'updated_at': datetime.now(timezone.utc)
        }
        await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)

        logger.info(
            "subscription_plan_record_updated",
            tenant_id=str(tenant_id),
            old_plan=old_plan,
            new_plan=new_plan,
            proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
        )
        return {
            "success": True,
            "message": f"Subscription plan record updated to {new_plan}",
            "old_plan": old_plan,
            "new_plan": new_plan,
            "proration_details": proration_details,
            "new_status": new_status,
            # Guarded so a missing period end cannot raise after the record
            # has already been updated.
            "new_period_end": new_period_end.isoformat() if new_period_end else None
        }
    except ValidationError as ve:
        logger.error("update_subscription_plan_record_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("update_subscription_plan_record_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to update subscription plan record: {str(e)}") from e
async def update_billing_cycle_record(
    self,
    tenant_id: str,
    new_billing_cycle: str,
    new_status: str,
    new_period_start: datetime,
    new_period_end: datetime,
    current_plan: str,
    proration_details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Update local billing cycle record in database

    Writes 'billing_cycle', 'status' and 'updated_at'. new_period_start,
    new_period_end and current_plan are accepted but not persisted (the
    local model has no current_period_start/current_period_end fields);
    new_period_end is echoed back in the result.

    Args:
        tenant_id: Tenant ID
        new_billing_cycle: New billing cycle ('monthly' or 'yearly')
        new_status: New subscription status
        new_period_start: New period start date
        new_period_end: New period end date
        current_plan: Current plan name
        proration_details: Proration details from payment provider

    Returns:
        Dictionary with billing cycle update results

    Raises:
        ValidationError: If the tenant has no subscription
        DatabaseError: On any other failure, including a malformed tenant_id
    """
    try:
        tenant_uuid = UUID(tenant_id)

        # Get current subscription
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        # Read the previous cycle BEFORE writing, otherwise an in-place
        # update of this object would make old and new look identical.
        old_billing_cycle = getattr(subscription, 'billing_cycle', 'monthly')

        # Update local subscription record. The new cycle must be part of
        # the write: without it the method reports a change that is never
        # stored.
        update_data = {
            'billing_cycle': new_billing_cycle,
            'status': new_status,
            'updated_at': datetime.now(timezone.utc)
        }
        await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)

        logger.info(
            "subscription_billing_cycle_record_updated",
            tenant_id=str(tenant_id),
            old_billing_cycle=old_billing_cycle,
            new_billing_cycle=new_billing_cycle,
            proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
        )
        return {
            "success": True,
            "message": f"Billing cycle record changed to {new_billing_cycle}",
            "old_billing_cycle": old_billing_cycle,
            "new_billing_cycle": new_billing_cycle,
            "proration_details": proration_details,
            "new_status": new_status,
            # Guarded so a missing period end cannot raise after the record
            # has already been updated.
            "new_period_end": new_period_end.isoformat() if new_period_end else None
        }
    except ValidationError as ve:
        logger.error("change_billing_cycle_validation_failed",
                     tenant_id=tenant_id, error=str(ve))
        raise
    except Exception as e:
        logger.error("change_billing_cycle_failed",
                     tenant_id=tenant_id, error=str(e))
        raise DatabaseError(f"Failed to change billing cycle: {str(e)}") from e
async def _invalidate_cache(self, tenant_id: str):
    """
    Best-effort invalidation of the cached subscription for a tenant.

    Failures are logged and swallowed here, so this call never raises into
    the calling operation.
    """
    tenant_key = str(tenant_id)
    try:
        # Imported lazily inside the helper, as in the original code.
        from app.services.subscription_cache import get_subscription_cache_service
        import shared.redis_utils

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        subscription_cache = get_subscription_cache_service(redis_client)
        await subscription_cache.invalidate_subscription_cache(tenant_key)
    except Exception as cache_error:
        logger.error(
            "Failed to invalidate subscription cache",
            tenant_id=tenant_key,
            error=str(cache_error)
        )
    else:
        logger.info(
            "Subscription cache invalidated",
            tenant_id=tenant_key
        )
async def validate_subscription_change(
    self,
    tenant_id: str,
    new_plan: str
) -> bool:
    """
    Tell whether the tenant's subscription may currently be changed.

    Only the state of the existing subscription is inspected; new_plan is
    accepted but not examined by this check.

    Args:
        tenant_id: Tenant ID
        new_plan: New plan to validate

    Returns:
        True if change is allowed; False when there is no subscription, it
        is pending cancellation or inactive, or the check itself failed
    """
    # Statuses in which a plan change is refused
    blocked_statuses = ('pending_cancellation', 'inactive')
    try:
        current = await self.get_subscription_by_tenant_id(tenant_id)
        if current and current.status not in blocked_statuses:
            return True
        return False
    except Exception as check_error:
        logger.error("validate_subscription_change_failed",
                     error=str(check_error), tenant_id=tenant_id)
        return False
# ========================================================================
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
# ========================================================================
async def create_tenant_independent_subscription_record(
    self,
    subscription_id: str,
    customer_id: str,
    plan: str,
    status: str,
    trial_period_days: Optional[int] = None,
    billing_interval: str = "monthly",
    user_id: Optional[str] = None
) -> Subscription:
    """
    Create a tenant-independent subscription record in the database

    This subscription is not linked to any tenant and will be linked during onboarding.
    It is stored with is_tenant_linked=False and tenant_linking_status='pending'.

    Args:
        subscription_id: Payment provider subscription ID
        customer_id: Payment provider customer ID
        plan: Subscription plan
        status: Subscription status
        trial_period_days: Optional trial period in days (ignored unless positive)
        billing_interval: Billing interval (monthly or yearly), stored as 'billing_cycle'
        user_id: User ID who created this subscription

    Returns:
        Created Subscription object

    Raises:
        ValidationError: Re-raised unchanged when the repository raises it
        DatabaseError: On any other failure
    """
    try:
        now = datetime.now(timezone.utc)

        # Create tenant-independent subscription record
        subscription_data = {
            'subscription_id': subscription_id,
            'customer_id': customer_id,
            'plan': plan,  # Repository expects 'plan', not 'plan_id'
            'status': status,
            'created_at': now,
            'billing_cycle': billing_interval,
            'user_id': user_id,
            'is_tenant_linked': False,
            'tenant_linking_status': 'pending'
        }

        # Add trial-related data if applicable
        if trial_period_days and trial_period_days > 0:
            subscription_data['trial_ends_at'] = now + timedelta(days=trial_period_days)

        created_subscription = await self.subscription_repo.create_tenant_independent_subscription(subscription_data)
        logger.info("tenant_independent_subscription_record_created",
                    subscription_id=subscription_id,
                    user_id=user_id,
                    plan=plan)
        return created_subscription
    except ValidationError as ve:
        logger.error("create_tenant_independent_subscription_record_validation_failed",
                     error=str(ve), user_id=user_id)
        raise
    except Exception as e:
        logger.error("create_tenant_independent_subscription_record_failed",
                     error=str(e), user_id=user_id)
        raise DatabaseError(f"Failed to create tenant-independent subscription record: {str(e)}") from e
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
    """
    List every subscription that is still waiting to be linked to a tenant.

    Returns:
        List of Subscription objects as returned by the repository

    Raises:
        DatabaseError: If the repository call fails
    """
    try:
        pending = await self.subscription_repo.get_pending_tenant_linking_subscriptions()
        return pending
    except Exception as repo_error:
        detail = str(repo_error)
        logger.error(f"Failed to get pending tenant linking subscriptions: {detail}")
        raise DatabaseError(f"Failed to get pending subscriptions: {detail}")
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
    """
    List the subscriptions of one user that still await tenant linking.

    Args:
        user_id: User ID whose pending subscriptions are requested

    Returns:
        List of Subscription objects as returned by the repository

    Raises:
        DatabaseError: If the repository call fails
    """
    try:
        pending = await self.subscription_repo.get_pending_subscriptions_by_user(user_id)
        return pending
    except Exception as repo_error:
        detail = str(repo_error)
        logger.error("Failed to get pending subscriptions by user",
                     user_id=user_id, error=detail)
        raise DatabaseError(f"Failed to get pending subscriptions: {detail}")
async def link_subscription_to_tenant(
    self,
    subscription_id: str,
    tenant_id: str,
    user_id: str
) -> Subscription:
    """
    Attach a pending subscription to a tenant.

    This finishes the registration flow: the subscription created during
    registration is associated with the tenant created during onboarding.
    The work itself is delegated to the repository.

    Args:
        subscription_id: Subscription ID to link
        tenant_id: Tenant ID to link to
        user_id: User ID performing the linking (for validation)

    Returns:
        Updated Subscription object

    Raises:
        DatabaseError: For any error raised by the repository call
    """
    try:
        linked = await self.subscription_repo.link_subscription_to_tenant(
            subscription_id, tenant_id, user_id
        )
        return linked
    except Exception as link_error:
        detail = str(link_error)
        logger.error("Failed to link subscription to tenant",
                     subscription_id=subscription_id,
                     tenant_id=tenant_id,
                     user_id=user_id,
                     error=detail)
        raise DatabaseError(f"Failed to link subscription to tenant: {detail}")
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
    """
    Clean up subscriptions that were never linked to tenants.

    Args:
        days_old: Age threshold in days, passed through to the repository

    Returns:
        The integer result of the repository cleanup call

    Raises:
        DatabaseError: If the repository call fails
    """
    try:
        outcome = await self.subscription_repo.cleanup_orphaned_subscriptions(days_old)
        return outcome
    except Exception as cleanup_error:
        failure_message = f"Failed to cleanup orphaned subscriptions: {str(cleanup_error)}"
        logger.error(failure_message)
        raise DatabaseError(failure_message)
async def update_subscription_info(
    self,
    subscription_id: str,
    update_data: Dict[str, Any]
) -> Subscription:
    """
    Apply a whitelisted set of field changes to a subscription (3DS flags,
    status, etc.).

    Useful for updating tenant-independent subscriptions during registration.
    Keys outside the whitelist are silently dropped. When nothing remains,
    no update is issued and the current record is fetched and returned.

    Args:
        subscription_id: Subscription ID
        update_data: Dictionary with fields to update

    Returns:
        Updated Subscription object

    Raises:
        ValidationError: If the update found no subscription
        DatabaseError: On any other failure
    """
    try:
        # Only these columns may be changed through this method
        permitted = {
            'plan', 'status', 'is_tenant_linked', 'tenant_linking_status',
            'threeds_authentication_required', 'threeds_authentication_required_at',
            'threeds_authentication_completed', 'threeds_authentication_completed_at',
            'last_threeds_setup_intent_id', 'threeds_action_type'
        }
        changes = {}
        for field, value in update_data.items():
            if field in permitted:
                changes[field] = value

        if len(changes) == 0:
            logger.warning("No valid subscription info fields provided for update",
                           subscription_id=subscription_id)
            return await self.subscription_repo.get_by_id(subscription_id)

        result = await self.subscription_repo.update(subscription_id, changes)
        if not result:
            raise ValidationError(f"Subscription not found: {subscription_id}")

        logger.info("Subscription info updated",
                    subscription_id=subscription_id,
                    updated_fields=list(changes.keys()))
        return result
    except ValidationError:
        raise
    except Exception as update_error:
        logger.error("Failed to update subscription info",
                     subscription_id=subscription_id,
                     error=str(update_error))
        raise DatabaseError(f"Failed to update subscription info: {str(update_error)}")