Add subcription feature 9

This commit is contained in:
Urtzi Alfaro
2026-01-16 20:25:45 +01:00
parent fa7b62bd6c
commit 3a7d57ef90
19 changed files with 1833 additions and 985 deletions

View File

@@ -737,15 +737,27 @@ async def validate_plan_upgrade(
async def upgrade_subscription_plan(
tenant_id: str = Path(..., description="Tenant ID"),
new_plan: str = Query(..., description="New plan name"),
billing_cycle: str = Query("monthly", description="Billing cycle (monthly/yearly)"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service),
orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
"""
Upgrade subscription plan for a tenant
Includes validation, cache invalidation, and token refresh.
This endpoint handles:
- Plan upgrade validation
- Stripe subscription update (preserves trial status if in trial)
- Local database update
- Cache invalidation
- Token refresh for immediate UI update
Trial handling:
- If user is in trial, they remain in trial after upgrade
- The upgraded tier price will be charged when trial ends
"""
try:
# Step 1: Validate upgrade eligibility
validation = await limit_service.validate_plan_upgrade(tenant_id, new_plan)
if not validation.get("can_upgrade", False):
raise HTTPException(
@@ -768,22 +780,77 @@ async def upgrade_subscription_plan(
detail="No active subscription found for this tenant"
)
old_plan = active_subscription.plan
is_trialing = active_subscription.status == 'trialing'
trial_ends_at = active_subscription.trial_ends_at
logger.info("Starting subscription upgrade",
extra={
"tenant_id": tenant_id,
"subscription_id": str(active_subscription.id),
"stripe_subscription_id": active_subscription.subscription_id,
"old_plan": old_plan,
"new_plan": new_plan,
"is_trialing": is_trialing,
"trial_ends_at": str(trial_ends_at) if trial_ends_at else None,
"user_id": current_user["user_id"]
})
# Step 2: Update Stripe subscription if Stripe subscription ID exists
stripe_updated = False
if active_subscription.subscription_id:
try:
# Use orchestration service to handle Stripe update with trial preservation
upgrade_result = await orchestration_service.orchestrate_plan_upgrade(
tenant_id=tenant_id,
new_plan=new_plan,
proration_behavior="none" if is_trialing else "create_prorations",
immediate_change=not is_trialing, # Don't change billing anchor if trialing
billing_cycle=billing_cycle
)
stripe_updated = True
logger.info("Stripe subscription updated successfully",
extra={
"tenant_id": tenant_id,
"stripe_subscription_id": active_subscription.subscription_id,
"upgrade_result": upgrade_result
})
except Exception as stripe_error:
logger.error("Failed to update Stripe subscription, falling back to local update only",
extra={"tenant_id": tenant_id, "error": str(stripe_error)})
# Continue with local update even if Stripe fails
# This ensures the user gets access to features immediately
# Step 3: Update local database
updated_subscription = await subscription_repo.update_subscription_plan(
str(active_subscription.id),
new_plan
)
# Preserve trial status if was trialing
if is_trialing and trial_ends_at:
# Ensure trial_ends_at is preserved after plan update
await subscription_repo.update_subscription_status(
str(active_subscription.id),
'trialing',
{'trial_ends_at': trial_ends_at}
)
await session.commit()
logger.info("Subscription plan upgraded successfully",
logger.info("Subscription plan upgraded successfully in database",
extra={
"tenant_id": tenant_id,
"subscription_id": str(active_subscription.id),
"old_plan": active_subscription.plan,
"old_plan": old_plan,
"new_plan": new_plan,
"stripe_updated": stripe_updated,
"preserved_trial": is_trialing,
"user_id": current_user["user_id"]
})
# Step 4: Invalidate subscription cache
redis_client = None
try:
from app.services.subscription_cache import get_subscription_cache_service
@@ -797,14 +864,17 @@ async def upgrade_subscription_plan(
logger.error("Failed to invalidate subscription cache after upgrade",
extra={"tenant_id": tenant_id, "error": str(cache_error)})
# Step 5: Invalidate tokens for immediate UI refresh
try:
await _invalidate_tenant_tokens(tenant_id, redis_client)
logger.info("Invalidated all tokens for tenant after subscription upgrade",
extra={"tenant_id": tenant_id})
if redis_client:
await _invalidate_tenant_tokens(tenant_id, redis_client)
logger.info("Invalidated all tokens for tenant after subscription upgrade",
extra={"tenant_id": tenant_id})
except Exception as token_error:
logger.error("Failed to invalidate tenant tokens after upgrade",
extra={"tenant_id": tenant_id, "error": str(token_error)})
# Step 6: Publish subscription change event for other services
try:
from shared.messaging import UnifiedEventPublisher
event_publisher = UnifiedEventPublisher()
@@ -813,9 +883,12 @@ async def upgrade_subscription_plan(
tenant_id=tenant_id,
data={
"tenant_id": tenant_id,
"old_tier": active_subscription.plan,
"old_tier": old_plan,
"new_tier": new_plan,
"action": "upgrade"
"action": "upgrade",
"is_trialing": is_trialing,
"trial_ends_at": trial_ends_at.isoformat() if trial_ends_at else None,
"stripe_updated": stripe_updated
}
)
logger.info("Published subscription change event",
@@ -826,10 +899,13 @@ async def upgrade_subscription_plan(
return {
"success": True,
"message": f"Plan successfully upgraded to {new_plan}",
"old_plan": active_subscription.plan,
"message": f"Plan successfully upgraded to {new_plan}" + (" (trial preserved)" if is_trialing else ""),
"old_plan": old_plan,
"new_plan": new_plan,
"new_monthly_price": updated_subscription.monthly_price,
"is_trialing": is_trialing,
"trial_ends_at": trial_ends_at.isoformat() if trial_ends_at else None,
"stripe_updated": stripe_updated,
"validation": validation,
"requires_token_refresh": True
}

View File

@@ -114,10 +114,12 @@ async def register_bakery(
error=str(linking_error))
elif bakery_data.coupon_code:
coupon_validation = payment_service.validate_coupon_code(
from app.services.coupon_service import CouponService
coupon_service = CouponService(db)
coupon_validation = await coupon_service.validate_coupon_code(
bakery_data.coupon_code,
tenant_id,
db
tenant_id
)
if not coupon_validation["valid"]:
@@ -131,10 +133,10 @@ async def register_bakery(
detail=coupon_validation["error_message"]
)
success, discount, error = payment_service.redeem_coupon(
success, discount, error = await coupon_service.redeem_coupon(
bakery_data.coupon_code,
tenant_id,
db
base_trial_days=0
)
if success:
@@ -194,13 +196,15 @@ async def register_bakery(
if coupon_validation and coupon_validation["valid"]:
from app.core.config import settings
from app.services.coupon_service import CouponService
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
async with database_manager.get_session() as session:
success, discount, error = payment_service.redeem_coupon(
coupon_service = CouponService(session)
success, discount, error = await coupon_service.redeem_coupon(
bakery_data.coupon_code,
result.id,
session
base_trial_days=0
)
if success:

View File

@@ -247,6 +247,50 @@ class SubscriptionRepository(TenantBaseRepository):
error=str(e))
raise DatabaseError(f"Failed to update plan: {str(e)}")
async def update_subscription_status(
self,
subscription_id: str,
status: str,
additional_data: Dict[str, Any] = None
) -> Optional[Subscription]:
"""Update subscription status with optional additional data"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": status,
"updated_at": datetime.utcnow()
}
# Merge additional data if provided
if additional_data:
update_data.update(additional_data)
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
if subscription.tenant_id:
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription status updated",
subscription_id=subscription_id,
new_status=status,
additional_data=additional_data)
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription status",
subscription_id=subscription_id,
status=status,
error=str(e))
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
async def cancel_subscription(
self,
subscription_id: str,

View File

@@ -852,7 +852,8 @@ class PaymentService:
proration_behavior: str = "create_prorations",
billing_cycle_anchor: Optional[str] = None,
payment_behavior: str = "error_if_incomplete",
immediate_change: bool = True
immediate_change: bool = True,
preserve_trial: bool = False
) -> Any:
"""
Update subscription price (plan upgrade/downgrade)
@@ -860,24 +861,33 @@ class PaymentService:
Args:
subscription_id: Stripe subscription ID
new_price_id: New Stripe price ID
proration_behavior: How to handle proration
proration_behavior: How to handle proration ('create_prorations', 'none', 'always_invoice')
billing_cycle_anchor: Billing cycle anchor ("now" or "unchanged")
payment_behavior: Payment behavior on update
immediate_change: Whether to apply changes immediately
preserve_trial: If True, preserves the trial period after upgrade
Returns:
Updated subscription object with .id, .status, etc. attributes
"""
try:
result = await retry_with_backoff(
lambda: self.stripe_client.update_subscription(subscription_id, new_price_id),
lambda: self.stripe_client.update_subscription(
subscription_id,
new_price_id,
proration_behavior=proration_behavior,
preserve_trial=preserve_trial
),
max_retries=3,
exceptions=(SubscriptionCreationFailed,)
)
logger.info("Subscription updated successfully",
subscription_id=subscription_id,
new_price_id=new_price_id)
new_price_id=new_price_id,
proration_behavior=proration_behavior,
preserve_trial=preserve_trial,
is_trialing=result.get('is_trialing', False))
# Create wrapper object for compatibility with callers expecting .id, .status etc.
class SubscriptionWrapper:
@@ -887,6 +897,8 @@ class PaymentService:
self.current_period_start = data.get('current_period_start')
self.current_period_end = data.get('current_period_end')
self.customer = data.get('customer_id')
self.trial_end = data.get('trial_end')
self.is_trialing = data.get('is_trialing', False)
return SubscriptionWrapper(result)

View File

@@ -638,17 +638,22 @@ class SubscriptionOrchestrationService:
billing_cycle: str = "monthly"
) -> Dict[str, Any]:
"""
Orchestrate plan upgrade workflow with proration
Orchestrate plan upgrade workflow with proration and trial preservation
Args:
tenant_id: Tenant ID
new_plan: New plan name
proration_behavior: Proration behavior
proration_behavior: Proration behavior ('create_prorations', 'none', 'always_invoice')
immediate_change: Whether to apply changes immediately
billing_cycle: Billing cycle for new plan
Returns:
Dictionary with upgrade results
Trial Handling:
- If subscription is in trial, the trial period is preserved
- No proration charges are created during trial
- After trial ends, the user is charged at the new tier price
"""
try:
logger.info("Starting plan upgrade orchestration",
@@ -665,33 +670,64 @@ class SubscriptionOrchestrationService:
if not subscription.subscription_id:
raise ValidationError(f"Tenant {tenant_id} does not have a Stripe subscription ID")
# Step 1.5: Check if subscription is in trial
is_trialing = subscription.status == 'trialing'
trial_ends_at = subscription.trial_ends_at
logger.info("Subscription trial status",
tenant_id=tenant_id,
is_trialing=is_trialing,
trial_ends_at=str(trial_ends_at) if trial_ends_at else None,
current_plan=subscription.plan)
# For trial subscriptions:
# - No proration charges (proration_behavior='none')
# - Preserve trial period
# - User gets new tier features immediately
# - User is charged new tier price when trial ends
if is_trialing:
proration_behavior = "none"
logger.info("Trial subscription detected, disabling proration",
tenant_id=tenant_id)
# Step 2: Get Stripe price ID for new plan
stripe_price_id = self.payment_service._get_stripe_price_id(new_plan, billing_cycle)
# Step 3: Calculate proration preview
proration_details = await self.payment_service.calculate_payment_proration(
subscription.subscription_id,
stripe_price_id,
proration_behavior
)
# Step 3: Calculate proration preview (only if not trialing)
proration_details = {}
if not is_trialing:
proration_details = await self.payment_service.calculate_payment_proration(
subscription.subscription_id,
stripe_price_id,
proration_behavior
)
logger.info("Proration calculated for plan upgrade",
tenant_id=tenant_id,
proration_amount=proration_details.get("net_amount", 0))
else:
proration_details = {
"subscription_id": subscription.subscription_id,
"new_price_id": stripe_price_id,
"proration_behavior": proration_behavior,
"net_amount": 0,
"trial_preserved": True
}
logger.info("Proration calculated for plan upgrade",
tenant_id=tenant_id,
proration_amount=proration_details.get("net_amount", 0))
# Step 4: Update in payment provider
# Step 4: Update in payment provider with trial preservation
updated_stripe_subscription = await self.payment_service.update_payment_subscription(
subscription.subscription_id,
stripe_price_id,
proration_behavior=proration_behavior,
billing_cycle_anchor="now" if immediate_change else "unchanged",
billing_cycle_anchor="now" if immediate_change and not is_trialing else "unchanged",
payment_behavior="error_if_incomplete",
immediate_change=immediate_change
immediate_change=immediate_change,
preserve_trial=is_trialing # Preserve trial if currently trialing
)
logger.info("Plan updated in payment provider",
subscription_id=updated_stripe_subscription.id,
new_status=updated_stripe_subscription.status)
new_status=updated_stripe_subscription.status,
is_trialing=getattr(updated_stripe_subscription, 'is_trialing', False))
# Step 5: Update local subscription record
update_result = await self.subscription_service.update_subscription_plan_record(
@@ -722,8 +758,12 @@ class SubscriptionOrchestrationService:
logger.info("Tenant plan information updated",
tenant_id=tenant_id)
# Add immediate_change to result
# Add upgrade metadata to result
update_result["immediate_change"] = immediate_change
update_result["is_trialing"] = is_trialing
update_result["trial_preserved"] = is_trialing
if trial_ends_at:
update_result["trial_ends_at"] = trial_ends_at.isoformat()
return update_result