Add subcription feature

This commit is contained in:
Urtzi Alfaro
2026-01-13 22:22:38 +01:00
parent b931a5c45e
commit 6ddf608d37
61 changed files with 7915 additions and 1238 deletions

View File

@@ -76,16 +76,24 @@ class StripeProvider(PaymentProvider):
plan_id=plan_id,
payment_method_id=payment_method_id)
# Attach payment method to customer with idempotency
stripe.PaymentMethod.attach(
payment_method_id,
customer=customer_id,
idempotency_key=payment_method_idempotency_key
)
logger.info("Payment method attached to customer",
customer_id=customer_id,
payment_method_id=payment_method_id)
# Attach payment method to customer with idempotency and error handling
try:
stripe.PaymentMethod.attach(
payment_method_id,
customer=customer_id,
idempotency_key=payment_method_idempotency_key
)
logger.info("Payment method attached to customer",
customer_id=customer_id,
payment_method_id=payment_method_id)
except stripe.error.InvalidRequestError as e:
# Payment method may already be attached
if 'already been attached' in str(e):
logger.warning("Payment method already attached to customer",
customer_id=customer_id,
payment_method_id=payment_method_id)
else:
raise
# Set customer's default payment method with idempotency
stripe.Customer.modify(
@@ -114,19 +122,36 @@ class StripeProvider(PaymentProvider):
trial_period_days=trial_period_days)
stripe_subscription = stripe.Subscription.create(**subscription_params)
logger.info("Stripe subscription created successfully",
subscription_id=stripe_subscription.id,
status=stripe_subscription.status,
current_period_end=stripe_subscription.current_period_end)
# Handle period dates for trial vs active subscriptions
# During trial: current_period_* fields are only in subscription items, not root
# After trial: current_period_* fields are at root level
if stripe_subscription.status == 'trialing' and stripe_subscription.items and stripe_subscription.items.data:
# For trial subscriptions, get period from first subscription item
first_item = stripe_subscription.items.data[0]
current_period_start = first_item.current_period_start
current_period_end = first_item.current_period_end
logger.info("Stripe trial subscription created successfully",
subscription_id=stripe_subscription.id,
status=stripe_subscription.status,
trial_end=stripe_subscription.trial_end,
current_period_end=current_period_end)
else:
# For active subscriptions, get period from root level
current_period_start = stripe_subscription.current_period_start
current_period_end = stripe_subscription.current_period_end
logger.info("Stripe subscription created successfully",
subscription_id=stripe_subscription.id,
status=stripe_subscription.status,
current_period_end=current_period_end)
return Subscription(
id=stripe_subscription.id,
customer_id=stripe_subscription.customer,
plan_id=plan_id, # Using the price ID as plan_id
status=stripe_subscription.status,
current_period_start=datetime.fromtimestamp(stripe_subscription.current_period_start),
current_period_end=datetime.fromtimestamp(stripe_subscription.current_period_end),
current_period_start=datetime.fromtimestamp(current_period_start),
current_period_end=datetime.fromtimestamp(current_period_end),
created_at=datetime.fromtimestamp(stripe_subscription.created)
)
except stripe.error.CardError as e:
@@ -155,12 +180,24 @@ class StripeProvider(PaymentProvider):
Update the payment method for a customer in Stripe
"""
try:
# Attach payment method to customer
stripe.PaymentMethod.attach(
payment_method_id,
customer=customer_id,
)
# Attach payment method to customer with error handling
try:
stripe.PaymentMethod.attach(
payment_method_id,
customer=customer_id,
)
logger.info("Payment method attached for update",
customer_id=customer_id,
payment_method_id=payment_method_id)
except stripe.error.InvalidRequestError as e:
# Payment method may already be attached
if 'already been attached' in str(e):
logger.warning("Payment method already attached, skipping attach",
customer_id=customer_id,
payment_method_id=payment_method_id)
else:
raise
# Set as default payment method
stripe.Customer.modify(
customer_id,
@@ -183,20 +220,54 @@ class StripeProvider(PaymentProvider):
logger.error("Failed to update Stripe payment method", error=str(e))
raise e
async def cancel_subscription(self, subscription_id: str) -> Subscription:
async def cancel_subscription(
self,
subscription_id: str,
cancel_at_period_end: bool = True
) -> Subscription:
"""
Cancel a subscription in Stripe
Args:
subscription_id: Stripe subscription ID
cancel_at_period_end: If True, subscription continues until end of billing period.
If False, cancels immediately.
Returns:
Updated Subscription object
"""
try:
stripe_subscription = stripe.Subscription.delete(subscription_id)
if cancel_at_period_end:
# Cancel at end of billing period (graceful cancellation)
stripe_subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=True
)
logger.info("Subscription set to cancel at period end",
subscription_id=subscription_id,
cancel_at=stripe_subscription.trial_end if stripe_subscription.status == 'trialing' else stripe_subscription.current_period_end)
else:
# Cancel immediately
stripe_subscription = stripe.Subscription.delete(subscription_id)
logger.info("Subscription cancelled immediately",
subscription_id=subscription_id)
# Handle period dates for trial vs active subscriptions
if stripe_subscription.status == 'trialing' and stripe_subscription.items and stripe_subscription.items.data:
first_item = stripe_subscription.items.data[0]
current_period_start = first_item.current_period_start
current_period_end = first_item.current_period_end
else:
current_period_start = stripe_subscription.current_period_start
current_period_end = stripe_subscription.current_period_end
return Subscription(
id=stripe_subscription.id,
customer_id=stripe_subscription.customer,
plan_id=subscription_id, # This would need to be retrieved differently in practice
plan_id=subscription_id,
status=stripe_subscription.status,
current_period_start=datetime.fromtimestamp(stripe_subscription.current_period_start),
current_period_end=datetime.fromtimestamp(stripe_subscription.current_period_end),
current_period_start=datetime.fromtimestamp(current_period_start),
current_period_end=datetime.fromtimestamp(current_period_end),
created_at=datetime.fromtimestamp(stripe_subscription.created)
)
except stripe.error.StripeError as e:
@@ -242,19 +313,291 @@ class StripeProvider(PaymentProvider):
"""
try:
stripe_subscription = stripe.Subscription.retrieve(subscription_id)
# Get the actual plan ID from the subscription items
plan_id = subscription_id # Default fallback
if stripe_subscription.items and stripe_subscription.items.data:
plan_id = stripe_subscription.items.data[0].price.id
# Handle period dates for trial vs active subscriptions
# During trial: current_period_* fields are only in subscription items, not root
# After trial: current_period_* fields are at root level
if stripe_subscription.status == 'trialing' and stripe_subscription.items and stripe_subscription.items.data:
# For trial subscriptions, get period from first subscription item
first_item = stripe_subscription.items.data[0]
current_period_start = first_item.current_period_start
current_period_end = first_item.current_period_end
else:
# For active subscriptions, get period from root level
current_period_start = stripe_subscription.current_period_start
current_period_end = stripe_subscription.current_period_end
return Subscription(
id=stripe_subscription.id,
customer_id=stripe_subscription.customer,
plan_id=subscription_id, # This would need to be retrieved differently in practice
plan_id=plan_id,
status=stripe_subscription.status,
current_period_start=datetime.fromtimestamp(stripe_subscription.current_period_start),
current_period_end=datetime.fromtimestamp(stripe_subscription.current_period_end),
created_at=datetime.fromtimestamp(stripe_subscription.created)
current_period_start=datetime.fromtimestamp(current_period_start),
current_period_end=datetime.fromtimestamp(current_period_end),
created_at=datetime.fromtimestamp(stripe_subscription.created),
billing_cycle_anchor=datetime.fromtimestamp(stripe_subscription.billing_cycle_anchor) if stripe_subscription.billing_cycle_anchor else None,
cancel_at_period_end=stripe_subscription.cancel_at_period_end
)
except stripe.error.StripeError as e:
logger.error("Failed to retrieve Stripe subscription", error=str(e))
raise e
async def update_subscription(
self,
subscription_id: str,
new_price_id: str,
proration_behavior: str = "create_prorations",
billing_cycle_anchor: str = "unchanged",
payment_behavior: str = "error_if_incomplete",
immediate_change: bool = False
) -> Subscription:
"""
Update a subscription in Stripe with proration support
Args:
subscription_id: Stripe subscription ID
new_price_id: New Stripe price ID to switch to
proration_behavior: How to handle prorations ('create_prorations', 'none', 'always_invoice')
billing_cycle_anchor: When to apply changes ('unchanged', 'now')
payment_behavior: Payment behavior ('error_if_incomplete', 'allow_incomplete')
immediate_change: Whether to apply changes immediately or at period end
Returns:
Updated Subscription object
"""
try:
logger.info("Updating Stripe subscription",
subscription_id=subscription_id,
new_price_id=new_price_id,
proration_behavior=proration_behavior,
immediate_change=immediate_change)
# Get current subscription to preserve settings
current_subscription = stripe.Subscription.retrieve(subscription_id)
# Build update parameters
update_params = {
'items': [{
'id': current_subscription.items.data[0].id,
'price': new_price_id,
}],
'proration_behavior': proration_behavior,
'billing_cycle_anchor': billing_cycle_anchor,
'payment_behavior': payment_behavior,
'expand': ['latest_invoice.payment_intent']
}
# If not immediate change, set cancel_at_period_end to False
# and let Stripe handle the transition
if not immediate_change:
update_params['cancel_at_period_end'] = False
update_params['proration_behavior'] = 'none' # No proration for end-of-period changes
# Update the subscription
updated_subscription = stripe.Subscription.modify(
subscription_id,
**update_params
)
logger.info("Stripe subscription updated successfully",
subscription_id=updated_subscription.id,
new_price_id=new_price_id,
status=updated_subscription.status)
# Get the actual plan ID from the subscription items
plan_id = new_price_id
if updated_subscription.items and updated_subscription.items.data:
plan_id = updated_subscription.items.data[0].price.id
# Handle period dates for trial vs active subscriptions
if updated_subscription.status == 'trialing' and updated_subscription.items and updated_subscription.items.data:
first_item = updated_subscription.items.data[0]
current_period_start = first_item.current_period_start
current_period_end = first_item.current_period_end
else:
current_period_start = updated_subscription.current_period_start
current_period_end = updated_subscription.current_period_end
return Subscription(
id=updated_subscription.id,
customer_id=updated_subscription.customer,
plan_id=plan_id,
status=updated_subscription.status,
current_period_start=datetime.fromtimestamp(current_period_start),
current_period_end=datetime.fromtimestamp(current_period_end),
created_at=datetime.fromtimestamp(updated_subscription.created),
billing_cycle_anchor=datetime.fromtimestamp(updated_subscription.billing_cycle_anchor) if updated_subscription.billing_cycle_anchor else None,
cancel_at_period_end=updated_subscription.cancel_at_period_end
)
except stripe.error.StripeError as e:
logger.error("Failed to update Stripe subscription",
error=str(e),
subscription_id=subscription_id,
new_price_id=new_price_id)
raise e
async def calculate_proration(
self,
subscription_id: str,
new_price_id: str,
proration_behavior: str = "create_prorations"
) -> Dict[str, Any]:
"""
Calculate proration amounts for a subscription change
Args:
subscription_id: Stripe subscription ID
new_price_id: New Stripe price ID
proration_behavior: Proration behavior to use
Returns:
Dictionary with proration details including amount, currency, and description
"""
try:
logger.info("Calculating proration for subscription change",
subscription_id=subscription_id,
new_price_id=new_price_id)
# Get current subscription
current_subscription = stripe.Subscription.retrieve(subscription_id)
current_price_id = current_subscription.items.data[0].price.id
# Get current and new prices
current_price = stripe.Price.retrieve(current_price_id)
new_price = stripe.Price.retrieve(new_price_id)
# Calculate time remaining in current billing period
current_period_end = datetime.fromtimestamp(current_subscription.current_period_end)
current_period_start = datetime.fromtimestamp(current_subscription.current_period_start)
now = datetime.now(timezone.utc)
total_period_days = (current_period_end - current_period_start).days
remaining_days = (current_period_end - now).days
used_days = (now - current_period_start).days
# Calculate prorated amounts
current_price_amount = current_price.unit_amount / 100.0 # Convert from cents
new_price_amount = new_price.unit_amount / 100.0
# Calculate daily rates
current_daily_rate = current_price_amount / total_period_days
new_daily_rate = new_price_amount / total_period_days
# Calculate proration based on behavior
if proration_behavior == "create_prorations":
# Calculate credit for unused time on current plan
unused_current_amount = current_daily_rate * remaining_days
# Calculate charge for remaining time on new plan
prorated_new_amount = new_daily_rate * remaining_days
# Net amount (could be positive or negative)
net_amount = prorated_new_amount - unused_current_amount
return {
"current_price_amount": current_price_amount,
"new_price_amount": new_price_amount,
"unused_current_amount": unused_current_amount,
"prorated_new_amount": prorated_new_amount,
"net_amount": net_amount,
"currency": current_price.currency.upper(),
"remaining_days": remaining_days,
"used_days": used_days,
"total_period_days": total_period_days,
"description": f"Proration for changing from {current_price_id} to {new_price_id}",
"is_credit": net_amount < 0
}
elif proration_behavior == "none":
return {
"current_price_amount": current_price_amount,
"new_price_amount": new_price_amount,
"net_amount": 0,
"currency": current_price.currency.upper(),
"description": "No proration - changes apply at period end",
"is_credit": False
}
else:
return {
"current_price_amount": current_price_amount,
"new_price_amount": new_price_amount,
"net_amount": new_price_amount - current_price_amount,
"currency": current_price.currency.upper(),
"description": "Full amount difference - immediate billing",
"is_credit": False
}
except stripe.error.StripeError as e:
logger.error("Failed to calculate proration",
error=str(e),
subscription_id=subscription_id,
new_price_id=new_price_id)
raise e
async def change_billing_cycle(
self,
subscription_id: str,
new_billing_cycle: str,
proration_behavior: str = "create_prorations"
) -> Subscription:
"""
Change billing cycle (monthly ↔ yearly) for a subscription
Args:
subscription_id: Stripe subscription ID
new_billing_cycle: New billing cycle ('monthly' or 'yearly')
proration_behavior: Proration behavior to use
Returns:
Updated Subscription object
"""
try:
logger.info("Changing billing cycle for subscription",
subscription_id=subscription_id,
new_billing_cycle=new_billing_cycle)
# Get current subscription
current_subscription = stripe.Subscription.retrieve(subscription_id)
current_price_id = current_subscription.items.data[0].price.id
# Get current price to determine the plan
current_price = stripe.Price.retrieve(current_price_id)
product_id = current_price.product
# Find the corresponding price for the new billing cycle
# This assumes you have price IDs set up for both monthly and yearly
# You would need to map this based on your product catalog
prices = stripe.Price.list(product=product_id, active=True)
new_price_id = None
for price in prices:
if price.recurring and price.recurring.interval == new_billing_cycle:
new_price_id = price.id
break
if not new_price_id:
raise ValueError(f"No {new_billing_cycle} price found for product {product_id}")
# Update the subscription with the new price
return await self.update_subscription(
subscription_id,
new_price_id,
proration_behavior=proration_behavior,
billing_cycle_anchor="now",
immediate_change=True
)
except stripe.error.StripeError as e:
logger.error("Failed to change billing cycle",
error=str(e),
subscription_id=subscription_id,
new_billing_cycle=new_billing_cycle)
raise e
async def get_customer(self, customer_id: str) -> PaymentCustomer:
"""