Files
bakery-ia/services/tenant/app/api/subscription.py
2026-01-16 20:25:45 +01:00

1265 lines
52 KiB
Python

"""
Subscription API - All subscription-related endpoints
This module contains all subscription-related endpoints following the new architecture:
PUBLIC ENDPOINTS (No Authentication):
- GET /api/v1/plans - Available from plans.py router
REGISTRATION FLOW (No Tenant Context):
- POST /api/v1/registration/payment-setup - Start registration with payment
- POST /api/v1/registration/complete - Complete registration after 3DS
- GET /api/v1/registration/state/{state_id} - Check registration state
TENANT-DEPENDENT SUBSCRIPTION ENDPOINTS:
- GET /api/v1/tenants/{tenant_id}/subscription/status - Get subscription status
- GET /api/v1/tenants/{tenant_id}/subscription/details - Get full subscription details
- GET /api/v1/tenants/{tenant_id}/subscription/tier - Get subscription tier (cached)
- GET /api/v1/tenants/{tenant_id}/subscription/limits - Get subscription limits
- GET /api/v1/tenants/{tenant_id}/subscription/usage - Get usage summary
- GET /api/v1/tenants/{tenant_id}/subscription/features/{feature} - Check feature access
SUBSCRIPTION MANAGEMENT:
- POST /api/v1/tenants/{tenant_id}/subscription/cancel - Cancel subscription
- POST /api/v1/tenants/{tenant_id}/subscription/reactivate - Reactivate subscription
- GET /api/v1/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan} - Validate upgrade
- POST /api/v1/tenants/{tenant_id}/subscription/upgrade - Upgrade subscription
QUOTA & LIMIT CHECKS:
- GET /api/v1/tenants/{tenant_id}/subscription/limits/locations - Check location limits
- GET /api/v1/tenants/{tenant_id}/subscription/limits/products - Check product limits
- GET /api/v1/tenants/{tenant_id}/subscription/limits/users - Check user limits
- GET /api/v1/tenants/{tenant_id}/subscription/limits/recipes - Check recipe limits
- GET /api/v1/tenants/{tenant_id}/subscription/limits/suppliers - Check supplier limits
PAYMENT MANAGEMENT:
- GET /api/v1/tenants/{tenant_id}/subscription/payment-method - Get payment method
- POST /api/v1/tenants/{tenant_id}/subscription/payment-method - Update payment method
- GET /api/v1/tenants/{tenant_id}/subscription/invoices - Get invoices
"""
import logging
import json
from typing import Dict, Any, Optional
from datetime import datetime, timezone
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
from app.services.subscription_limit_service import SubscriptionLimitService
from app.services.coupon_service import CouponService
from app.core.database import get_db
from app.core.config import settings
from app.models.tenants import Subscription
from app.services.registration_state_service import (
registration_state_service,
RegistrationStateService,
RegistrationState
)
from shared.exceptions.payment_exceptions import (
PaymentServiceError,
SetupIntentError,
SubscriptionCreationFailed,
)
from shared.exceptions.registration_exceptions import (
RegistrationStateError,
)
from shared.auth.decorators import get_current_user_dep
from shared.database.base import create_database_manager
import shared.redis_utils
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router collecting every endpoint declared in this module under the /api/v1 prefix.
router = APIRouter(prefix="/api/v1", tags=["subscription"])
# Global Redis client
# Starts as None and is populated on first use by get_subscription_redis_client().
_redis_client = None
# ============================================================================
# DEPENDENCY INJECTION
# ============================================================================
async def get_subscription_orchestration_service(
    db: AsyncSession = Depends(get_db),
) -> SubscriptionOrchestrationService:
    """Build a SubscriptionOrchestrationService bound to the injected DB session."""
    orchestration_service = SubscriptionOrchestrationService(db)
    return orchestration_service
async def get_registration_state_service() -> RegistrationStateService:
    """Hand the module-level ``registration_state_service`` instance to endpoints."""
    shared_state_service = registration_state_service
    return shared_state_service
async def get_subscription_redis_client():
    """Return the module-wide Redis client, creating it on first use.

    Returns:
        The cached client, or None when initialization raises (the failure is
        logged as a warning instead of being propagated).
    """
    global _redis_client
    # Fast path: a previous call already initialised the client.
    if _redis_client is not None:
        return _redis_client
    try:
        _redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        logger.info("Redis client initialized for subscription service")
    except Exception as exc:
        logger.warning("Failed to initialize Redis client", extra={"error": str(exc)})
        return None
    return _redis_client
def get_subscription_limit_service():
    """Construct a SubscriptionLimitService for the tenant-service database.

    The service is built without a Redis client (``None`` is passed).

    Raises:
        HTTPException: 500 when the database manager or the service cannot be
            created.
    """
    try:
        db_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        limit_service = SubscriptionLimitService(db_manager, None)
    except Exception as exc:
        logger.error("Failed to create subscription limit service", extra={"error": str(exc)})
        raise HTTPException(status_code=500, detail="Service initialization failed")
    return limit_service
# ============================================================================
# REGISTRATION FLOW ENDPOINTS (No Tenant Context)
# ============================================================================
@router.post("/registration/payment-setup",
             response_model=Dict[str, Any],
             summary="Start registration payment setup")
async def create_registration_payment_setup(
    user_data: Dict[str, Any],
    request: Request,
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service),
    state_service: RegistrationStateService = Depends(get_registration_state_service)
) -> Dict[str, Any]:
    """
    Start registration payment setup (SetupIntent-first architecture).
    Creates customer + SetupIntent only (NO subscription).
    Subscription is created in /registration/complete after 3DS verification.

    If anything fails after the registration state has been created, the state
    is marked as failed (best effort) before the error is reported, mirroring
    the behaviour of /registration/complete.

    Args:
        user_data: User registration data with payment info
            - email (required)
            - payment_method_id (required)
            - plan_id (required)
            - billing_cycle (optional, defaults to 'monthly')
            - coupon_code (optional)
        request: Incoming request (currently unused).
        orchestration_service: Service that performs the payment setup.
        state_service: Service that tracks the registration state.

    Returns:
        SetupIntent data for frontend confirmation, including the ``state_id``
        of the registration state created for this attempt.

    Raises:
        HTTPException: 400 when a required field is missing; 500 when the
            payment setup, the state tracking, or anything else fails.
    """
    state_id = None
    try:
        logger.info("Registration payment setup started",
                    extra={"email": user_data.get('email'), "plan_id": user_data.get('plan_id')})

        # Validate required fields before touching any external service.
        if not user_data.get('email'):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is required")
        if not user_data.get('payment_method_id'):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Payment method ID is required")
        if not user_data.get('plan_id'):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Plan ID is required")

        state_id = await state_service.create_registration_state(
            email=user_data['email'],
            user_data=user_data
        )

        try:
            result = await orchestration_service.create_registration_payment_setup(
                user_data=user_data,
                plan_id=user_data.get('plan_id', 'professional'),
                payment_method_id=user_data.get('payment_method_id'),
                billing_interval=user_data.get('billing_cycle', 'monthly'),
                coupon_code=user_data.get('coupon_code')
            )
            await state_service.update_state_context(state_id, {
                'setup_intent_id': result.get('setup_intent_id'),
                'customer_id': result.get('customer_id'),
                'payment_method_id': result.get('payment_method_id'),
                'plan_id': result.get('plan_id'),
                'billing_interval': result.get('billing_interval'),
                'trial_period_days': result.get('trial_period_days'),
                'coupon_code': result.get('coupon_code')
            })
            await state_service.transition_state(state_id, RegistrationState.PAYMENT_VERIFICATION_PENDING)
        except Exception as setup_error:
            # The state already exists at this point: record the failure so it
            # does not linger in a non-terminal state. This is best effort and
            # must never mask the original error, which is re-raised for the
            # outer handlers to translate into an HTTP response.
            try:
                await state_service.mark_registration_failed(
                    state_id, f"Payment setup failed: {setup_error}"
                )
            except Exception as mark_error:
                logger.warning(f"Failed to mark registration as failed: {mark_error}",
                               extra={"state_id": state_id})
            raise

        logger.info("Registration payment setup completed",
                    extra={
                        "email": user_data.get('email'),
                        "setup_intent_id": result.get('setup_intent_id'),
                        "requires_action": result.get('requires_action')
                    })
        return {
            "success": True,
            "requires_action": result.get('requires_action', True),
            "action_type": result.get('action_type', 'use_stripe_sdk'),
            "client_secret": result.get('client_secret'),
            "setup_intent_id": result.get('setup_intent_id'),
            "customer_id": result.get('customer_id'),
            # Same value as customer_id, exposed under a second key.
            "payment_customer_id": result.get('customer_id'),
            "plan_id": result.get('plan_id'),
            "payment_method_id": result.get('payment_method_id'),
            "trial_period_days": result.get('trial_period_days', 0),
            "billing_cycle": result.get('billing_interval'),
            "email": result.get('email'),
            "state_id": state_id,
            "message": result.get('message', 'Payment verification required')
        }
    except PaymentServiceError as e:
        logger.error(f"Payment setup failed: {str(e)}", extra={"email": user_data.get('email')}, exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Payment setup failed: {str(e)}") from e
    except RegistrationStateError as e:
        logger.error(f"Registration state error: {str(e)}", extra={"email": user_data.get('email')}, exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Registration state error: {str(e)}") from e
    except HTTPException:
        # Validation errors raised above already carry the right status code.
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", extra={"email": user_data.get('email')}, exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Registration failed: {str(e)}") from e
async def _mark_completion_failed(
    state_service: RegistrationStateService,
    state_id: Optional[str],
    reason: str
) -> None:
    """Best-effort: flag the registration state as failed without masking the original error."""
    if not state_id:
        return
    try:
        await state_service.mark_registration_failed(state_id, reason)
    except Exception as mark_error:
        # Previously swallowed silently; log it so stuck states can be diagnosed.
        logger.warning(f"Failed to mark registration as failed: {mark_error}",
                       extra={"state_id": state_id})


@router.post("/registration/complete",
             response_model=Dict[str, Any],
             summary="Complete registration after 3DS verification")
async def verify_and_complete_registration(
    verification_data: Dict[str, Any],
    request: Request,
    db: AsyncSession = Depends(get_db),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service),
    state_service: RegistrationStateService = Depends(get_registration_state_service)
) -> Dict[str, Any]:
    """
    Complete registration after frontend confirms SetupIntent.
    Creates subscription AFTER payment verification is complete.
    This is the ONLY place subscriptions are created during registration.

    Args:
        verification_data: SetupIntent verification data with user_data
            - setup_intent_id (required)
            - user_data (required)
            - state_id (optional)
        request: Incoming request (currently unused).
        db: Database session, used for coupon redemption.
        orchestration_service: Service that creates the subscription.
        state_service: Service that tracks the registration state.

    Returns:
        Subscription creation result

    Raises:
        HTTPException: 400 when a required field is missing or the SetupIntent
            verification fails; 500 when subscription creation or anything
            else fails.
    """
    setup_intent_id = None
    user_data = {}
    state_id = None
    try:
        if not verification_data.get('setup_intent_id'):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="SetupIntent ID is required")
        if not verification_data.get('user_data'):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User data is required")

        setup_intent_id = verification_data['setup_intent_id']
        user_data = verification_data['user_data']
        state_id = verification_data.get('state_id')
        logger.info("Completing registration after verification",
                    extra={"email": user_data.get('email'), "setup_intent_id": setup_intent_id})

        # Work out the trial length: a coupon takes precedence over an
        # explicitly supplied trial_period_days.
        trial_period_days = 0
        coupon_code = user_data.get('coupon_code')
        if coupon_code:
            logger.info("Validating coupon in completion call",
                        extra={"coupon_code": coupon_code, "email": user_data.get('email')})
            coupon_service = CouponService(db)
            success, discount_applied, error = await coupon_service.redeem_coupon(
                coupon_code,
                None,
                base_trial_days=0
            )
            if success and discount_applied:
                trial_period_days = discount_applied.get("total_trial_days", 0)
                logger.info("Coupon validated in completion call",
                            extra={"coupon_code": coupon_code, "trial_period_days": trial_period_days})
            else:
                # A rejected coupon does not abort registration; it only means no trial days.
                logger.warning("Failed to validate coupon in completion call",
                               extra={"coupon_code": coupon_code, "error": error})
        elif 'trial_period_days' in user_data:
            # NOTE(review): this value comes straight from the request body -
            # confirm that clients are allowed to choose their own trial length.
            trial_period_days = int(user_data.get('trial_period_days', 0))
            logger.info("Using explicitly provided trial period",
                        extra={"trial_period_days": trial_period_days})

        result = await orchestration_service.complete_registration_subscription(
            setup_intent_id=setup_intent_id,
            customer_id=user_data.get('customer_id', ''),
            plan_id=user_data.get('plan_id') or user_data.get('subscription_plan', 'professional'),
            payment_method_id=user_data.get('payment_method_id', ''),
            billing_interval=user_data.get('billing_cycle') or user_data.get('billing_interval', 'monthly'),
            trial_period_days=trial_period_days,
            user_id=user_data.get('user_id')
        )

        if state_id:
            # State tracking is secondary: the subscription already exists, so
            # a failure here is logged but does not fail the request.
            try:
                await state_service.update_state_context(state_id, {
                    'subscription_id': result['subscription_id'],
                    'status': result['status']
                })
                await state_service.transition_state(state_id, RegistrationState.SUBSCRIPTION_CREATED)
            except Exception as e:
                logger.warning(f"Failed to update registration state: {e}", extra={"state_id": state_id})

        logger.info("Registration subscription created successfully",
                    extra={
                        "email": user_data.get('email'),
                        "subscription_id": result['subscription_id'],
                        "status": result['status']
                    })
        return {
            "success": True,
            "subscription_id": result['subscription_id'],
            "customer_id": result['customer_id'],
            "payment_customer_id": result.get('payment_customer_id', result['customer_id']),
            "status": result['status'],
            "plan_id": result.get('plan_id'),
            "payment_method_id": result.get('payment_method_id'),
            "trial_period_days": result.get('trial_period_days', 0),
            "current_period_end": result.get('current_period_end'),
            "state_id": state_id,
            "message": "Subscription created successfully"
        }
    except SetupIntentError as e:
        logger.error(f"SetupIntent verification failed: {e}",
                     extra={"setup_intent_id": setup_intent_id, "email": user_data.get('email')},
                     exc_info=True)
        await _mark_completion_failed(state_service, state_id, f"Verification failed: {e}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payment verification failed: {e}") from e
    except SubscriptionCreationFailed as e:
        logger.error(f"Subscription creation failed: {e}",
                     extra={"setup_intent_id": setup_intent_id, "email": user_data.get('email')},
                     exc_info=True)
        await _mark_completion_failed(state_service, state_id, f"Subscription failed: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Subscription creation failed: {e}") from e
    except HTTPException:
        # Validation errors raised above already carry the right status code.
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}",
                     extra={"setup_intent_id": setup_intent_id, "email": user_data.get('email')},
                     exc_info=True)
        await _mark_completion_failed(state_service, state_id, f"Registration failed: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Registration failed: {e}") from e
@router.get(
    "/registration/state/{state_id}",
    response_model=Dict[str, Any],
    summary="Get registration state",
)
async def get_registration_state(
    state_id: str = Path(..., description="Registration state ID"),
    request: Request = None,
    state_service: RegistrationStateService = Depends(get_registration_state_service),
) -> Dict[str, Any]:
    """
    Look up a registration state and return a fixed selection of its fields.

    Args:
        state_id: Registration state ID taken from the URL path.
        request: Unused by this handler.
        state_service: Service used to load the stored state.

    Returns:
        Dict with the always-present fields (state_id, email, current_state,
        created_at, updated_at) followed by the optional ones, which are None
        when the stored state does not contain them.

    Raises:
        HTTPException: 404 when the state service raises
            RegistrationStateError, 500 for any other failure.
    """
    # Fields read with indexing (must exist) vs. fields read with .get().
    required_fields = ("state_id", "email", "current_state", "created_at", "updated_at")
    optional_fields = ("setup_intent_id", "customer_id", "subscription_id", "error", "user_data")
    try:
        logger.info("Getting registration state", extra={"state_id": state_id})
        state_data = await state_service.get_registration_state(state_id)
        logger.info(
            "Registration state retrieved",
            extra={"state_id": state_id, "current_state": state_data['current_state']},
        )
        response = {field: state_data[field] for field in required_fields}
        for field in optional_fields:
            response[field] = state_data.get(field)
        return response
    except RegistrationStateError as e:
        logger.error(
            "Registration state not found",
            extra={"error": str(e), "state_id": state_id},
        )
        not_found = HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Registration state not found: {str(e)}",
        )
        raise not_found from e
    except Exception as e:
        logger.error(
            "Unexpected error getting registration state",
            extra={"error": str(e), "state_id": state_id},
            exc_info=True,
        )
        server_error = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get registration state: {str(e)}",
        )
        raise server_error from e
# ============================================================================
# TENANT SUBSCRIPTION STATUS ENDPOINTS
# ============================================================================
@router.get(
    "/tenants/{tenant_id}/subscription/status",
    response_model=Dict[str, Any],
    summary="Get subscription status",
)
async def get_subscription_status(
    tenant_id: str = Path(..., description="Tenant ID"),
    db: AsyncSession = Depends(get_db),
) -> Dict[str, Any]:
    """
    Report the tenant's subscription status for read-only mode enforcement.

    Only a subscription whose status is exactly "active" counts; when none is
    found the tenant is reported as inactive and read-only.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    # NOTE(review): subscriptions in any other status (e.g. 'trialing') are
    # reported as inactive/read-only here - confirm that this is intended.
    try:
        active_query = select(Subscription).where(
            Subscription.tenant_id == tenant_id,
            Subscription.status == "active",
        )
        query_result = await db.execute(active_query)
        subscription = query_result.scalars().first()

        if not subscription:
            return {
                "status": "inactive",
                "plan": None,
                "is_read_only": True,
                "cancellation_effective_date": None,
            }

        response = {
            "status": subscription.status,
            "plan": subscription.plan,
            "is_read_only": False,
        }
        if subscription.cancellation_effective_date:
            response["cancellation_effective_date"] = subscription.cancellation_effective_date.isoformat()
        else:
            response["cancellation_effective_date"] = None
        return response
    except Exception as e:
        logger.error(f"Failed to get subscription status: {str(e)}")
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get subscription status: {str(e)}",
        )
        raise failure from e
@router.get(
    "/tenants/{tenant_id}/subscription/details",
    response_model=Dict[str, Any],
    summary="Get full subscription details",
)
async def get_subscription_details(
    tenant_id: str = Path(..., description="Tenant ID"),
    redis_client=Depends(get_subscription_redis_client),
) -> Dict[str, Any]:
    """
    Return the tenant's subscription as provided by the subscription cache service.

    Raises:
        HTTPException: 404 when the cache service returns nothing for the
            tenant, 500 when the lookup itself fails.
    """
    try:
        from app.services.subscription_cache import get_subscription_cache_service

        cache_service = get_subscription_cache_service(redis_client)
        subscription = await cache_service.get_tenant_subscription_cached(tenant_id)
        if subscription:
            return subscription
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No active subscription found",
        )
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as exc:
        log_context = {"tenant_id": tenant_id, "error": str(exc)}
        logger.error("Failed to get subscription details", extra=log_context)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get subscription details",
        )
@router.get(
    "/tenants/{tenant_id}/subscription/tier",
    response_model=Dict[str, Any],
    summary="Get subscription tier (cached)",
)
async def get_subscription_tier(
    tenant_id: str = Path(..., description="Tenant ID"),
    redis_client=Depends(get_subscription_redis_client),
) -> Dict[str, Any]:
    """
    Look up the tenant's subscription tier through the subscription cache service.

    Intended for high-frequency callers such as gateway middleware. The
    ``cached`` flag in the response is always True.

    Raises:
        HTTPException: 500 when the lookup fails.
    """
    try:
        from app.services.subscription_cache import get_subscription_cache_service

        tier = await get_subscription_cache_service(redis_client).get_tenant_tier_cached(tenant_id)
        response = {"tenant_id": tenant_id, "tier": tier, "cached": True}
        return response
    except Exception as exc:
        log_context = {"tenant_id": tenant_id, "error": str(exc)}
        logger.error("Failed to get subscription tier", extra=log_context)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get subscription tier",
        )
@router.get(
    "/tenants/{tenant_id}/subscription/limits",
    response_model=Dict[str, Any],
    summary="Get subscription limits",
)
async def get_subscription_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service),
) -> Dict[str, Any]:
    """Return the tenant's subscription limits as reported by the limit service.

    Raises:
        HTTPException: 500 when the limit service fails.
    """
    try:
        return await limit_service.get_tenant_subscription_limits(tenant_id)
    except Exception as exc:
        log_context = {"tenant_id": tenant_id, "error": str(exc)}
        logger.error("Failed to get subscription limits", extra=log_context)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get subscription limits",
        )
@router.get("/tenants/{tenant_id}/subscription/usage",
            response_model=Dict[str, Any],
            summary="Get usage summary")
async def get_usage_summary(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """Get usage summary vs limits for a tenant (cached for 30s for performance)

    The Redis cache is purely an optimisation: a failure to read or write the
    cache is logged and the summary is computed/returned anyway.

    Args:
        tenant_id: Tenant whose usage is requested.
        current_user: Authenticated user (dependency; not otherwise used here).
        limit_service: Service that computes the usage summary.

    Returns:
        The usage summary, either from the cache or freshly computed.

    Raises:
        HTTPException: 500 when the usage summary itself cannot be computed.
    """
    try:
        from shared.redis_utils import get_redis_client
        cache_key = f"usage_summary:{tenant_id}"

        # Cache read is best effort: a Redis outage or an unreadable entry
        # falls through to a fresh computation.
        redis_client = None
        try:
            redis_client = await get_redis_client()
            if redis_client:
                cached = await redis_client.get(cache_key)
                if cached:
                    logger.debug("Usage summary cache hit", extra={"tenant_id": tenant_id})
                    return json.loads(cached)
        except Exception as cache_error:
            logger.warning("Usage summary cache read failed",
                           extra={"tenant_id": tenant_id, "error": str(cache_error)})

        usage = await limit_service.get_usage_summary(tenant_id)

        # Cache write is best effort too: the caller still gets the computed
        # summary when Redis is unavailable or the value cannot be serialised.
        if redis_client:
            try:
                await redis_client.setex(cache_key, 30, json.dumps(usage))
                logger.debug("Usage summary cached", extra={"tenant_id": tenant_id})
            except Exception as cache_error:
                logger.warning("Usage summary cache write failed",
                               extra={"tenant_id": tenant_id, "error": str(cache_error)})
        return usage
    except Exception as e:
        logger.error("Failed to get usage summary",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get usage summary"
        )
@router.get(
    "/tenants/{tenant_id}/subscription/features/{feature}",
    response_model=Dict[str, Any],
    summary="Check feature access",
)
async def check_feature_access(
    tenant_id: str = Path(..., description="Tenant ID"),
    feature: str = Path(..., description="Feature name"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service),
) -> Dict[str, Any]:
    """Ask the limit service whether the tenant may use the named feature.

    Returns the limit service's answer unchanged.

    Raises:
        HTTPException: 500 when the limit service fails.
    """
    try:
        return await limit_service.has_feature(tenant_id, feature)
    except Exception as exc:
        log_context = {"tenant_id": tenant_id, "feature": feature, "error": str(exc)}
        logger.error("Failed to check feature access", extra=log_context)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check feature access",
        )
# ============================================================================
# SUBSCRIPTION MANAGEMENT ENDPOINTS
# ============================================================================
@router.post(
    "/tenants/{tenant_id}/subscription/cancel",
    response_model=Dict[str, Any],
    summary="Cancel subscription",
)
async def cancel_subscription(
    tenant_id: str = Path(..., description="Tenant ID"),
    request: Request = None,
    reason: str = Query("", description="Cancellation reason"),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service),
) -> Dict[str, Any]:
    """
    Cancel the tenant's subscription via the orchestration service.

    The response echoes the orchestration result, falling back to a default
    message and a "pending_cancellation" status when those keys are absent.

    Raises:
        HTTPException: 500 when the cancellation fails.
    """
    # NOTE(review): unlike the limit/usage endpoints, this handler declares no
    # current-user dependency - confirm authentication is enforced upstream.
    # (response key, fallback when the orchestration result lacks the key)
    response_fields = (
        ("message", "Subscription cancellation initiated"),
        ("status", "pending_cancellation"),
        ("cancellation_effective_date", None),
        ("days_remaining", None),
        ("read_only_mode_starts", None),
    )
    try:
        outcome = await orchestration_service.orchestrate_subscription_cancellation(
            tenant_id,
            reason,
        )
        response = {"success": True}
        for key, fallback in response_fields:
            response[key] = outcome.get(key, fallback)
        return response
    except Exception as e:
        logger.error(
            "Failed to cancel subscription",
            extra={"error": str(e), "tenant_id": tenant_id},
            exc_info=True,
        )
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to cancel subscription: {str(e)}",
        )
        raise failure from e
@router.post(
    "/tenants/{tenant_id}/subscription/reactivate",
    response_model=Dict[str, Any],
    summary="Reactivate subscription",
)
async def reactivate_subscription(
    tenant_id: str = Path(..., description="Tenant ID"),
    plan: str = Query("starter", description="Plan to reactivate to"),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service),
) -> Dict[str, Any]:
    """
    Reactivate the tenant's subscription on the requested plan.

    The response echoes the orchestration result, falling back to a default
    message, an "active" status and the requested plan when those keys are
    absent.

    Raises:
        HTTPException: 500 when the reactivation fails.
    """
    try:
        outcome = await orchestration_service.orchestrate_subscription_reactivation(
            tenant_id,
            plan,
        )
        response = {"success": True}
        response["message"] = outcome.get("message", "Subscription reactivated successfully")
        response["status"] = outcome.get("status", "active")
        response["plan"] = outcome.get("plan", plan)
        response["next_billing_date"] = outcome.get("next_billing_date")
        return response
    except Exception as e:
        logger.error(
            "Failed to reactivate subscription",
            extra={"error": str(e), "tenant_id": tenant_id},
            exc_info=True,
        )
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to reactivate subscription: {str(e)}",
        )
        raise failure from e
@router.get(
    "/tenants/{tenant_id}/subscription/validate-upgrade/{new_plan}",
    response_model=Dict[str, Any],
    summary="Validate plan upgrade eligibility",
)
async def validate_plan_upgrade(
    tenant_id: str = Path(..., description="Tenant ID"),
    new_plan: str = Path(..., description="New plan name"),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service),
) -> Dict[str, Any]:
    """
    Ask the orchestration service whether the tenant may move to ``new_plan``.

    Returns the orchestration service's validation result unchanged.

    Raises:
        HTTPException: 500 when the validation call fails.
    """
    try:
        return await orchestration_service.validate_plan_upgrade(tenant_id, new_plan)
    except Exception as e:
        logger.error(
            "Failed to validate plan upgrade",
            extra={"error": str(e), "tenant_id": tenant_id, "new_plan": new_plan},
            exc_info=True,
        )
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to validate plan upgrade: {str(e)}",
        )
        raise failure from e
@router.post("/tenants/{tenant_id}/subscription/upgrade",
             response_model=Dict[str, Any],
             summary="Upgrade subscription plan")
async def upgrade_subscription_plan(
    tenant_id: str = Path(..., description="Tenant ID"),
    new_plan: str = Query(..., description="New plan name"),
    billing_cycle: str = Query("monthly", description="Billing cycle (monthly/yearly)"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Upgrade subscription plan for a tenant

    This endpoint handles:
    - Plan upgrade validation
    - Stripe subscription update (preserves trial status if in trial)
    - Local database update
    - Cache invalidation
    - Token refresh for immediate UI update

    Trial handling:
    - If user is in trial, they remain in trial after upgrade
    - The upgraded tier price will be charged when trial ends

    Cache invalidation, token invalidation and event publishing are best
    effort: their failures are logged and do not fail the request.

    Args:
        tenant_id: Tenant whose subscription is upgraded.
        new_plan: Name of the plan to upgrade to.
        billing_cycle: Billing cycle forwarded to the orchestration service.
        current_user: Authenticated user; ``user_id`` is read for logging.
        limit_service: Service used to validate the upgrade.
        orchestration_service: Service used to update the Stripe subscription.

    Returns:
        Summary of the upgrade (old/new plan, price, trial info, whether
        Stripe was updated, and the validation result).

    Raises:
        HTTPException: 400 when validation rejects the upgrade, 404 when the
            tenant has no active subscription, 500 for any other failure.
    """
    try:
        # Step 1: Validate upgrade eligibility
        validation = await limit_service.validate_plan_upgrade(tenant_id, new_plan)
        if not validation.get("can_upgrade", False):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=validation.get("reason", "Cannot upgrade to this plan")
            )

        from app.repositories.subscription_repository import SubscriptionRepository
        database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        async with database_manager.get_session() as session:
            subscription_repo = SubscriptionRepository(Subscription, session)
            active_subscription = await subscription_repo.get_active_subscription(tenant_id)
            if not active_subscription:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="No active subscription found for this tenant"
                )

            # Capture the pre-upgrade state; it is reported back to the caller
            # and decides how the Stripe update is performed.
            old_plan = active_subscription.plan
            is_trialing = active_subscription.status == 'trialing'
            trial_ends_at = active_subscription.trial_ends_at
            logger.info("Starting subscription upgrade",
                        extra={
                            "tenant_id": tenant_id,
                            "subscription_id": str(active_subscription.id),
                            "stripe_subscription_id": active_subscription.subscription_id,
                            "old_plan": old_plan,
                            "new_plan": new_plan,
                            "is_trialing": is_trialing,
                            "trial_ends_at": str(trial_ends_at) if trial_ends_at else None,
                            "user_id": current_user["user_id"]
                        })

            # Step 2: Update Stripe subscription if Stripe subscription ID exists
            stripe_updated = False
            if active_subscription.subscription_id:
                try:
                    # Use orchestration service to handle Stripe update with trial preservation
                    upgrade_result = await orchestration_service.orchestrate_plan_upgrade(
                        tenant_id=tenant_id,
                        new_plan=new_plan,
                        proration_behavior="none" if is_trialing else "create_prorations",
                        immediate_change=not is_trialing,  # Don't change billing anchor if trialing
                        billing_cycle=billing_cycle
                    )
                    stripe_updated = True
                    logger.info("Stripe subscription updated successfully",
                                extra={
                                    "tenant_id": tenant_id,
                                    "stripe_subscription_id": active_subscription.subscription_id,
                                    "upgrade_result": upgrade_result
                                })
                except Exception as stripe_error:
                    # NOTE(review): the plan is still upgraded locally when
                    # Stripe fails, so the tenant may get the new tier without
                    # being billed for it - confirm this trade-off is intended.
                    logger.error("Failed to update Stripe subscription, falling back to local update only",
                                 extra={"tenant_id": tenant_id, "error": str(stripe_error)})
                    # Continue with local update even if Stripe fails
                    # This ensures the user gets access to features immediately

            # Step 3: Update local database
            updated_subscription = await subscription_repo.update_subscription_plan(
                str(active_subscription.id),
                new_plan
            )
            # Preserve trial status if was trialing
            if is_trialing and trial_ends_at:
                # Ensure trial_ends_at is preserved after plan update
                await subscription_repo.update_subscription_status(
                    str(active_subscription.id),
                    'trialing',
                    {'trial_ends_at': trial_ends_at}
                )
            await session.commit()
            logger.info("Subscription plan upgraded successfully in database",
                        extra={
                            "tenant_id": tenant_id,
                            "subscription_id": str(active_subscription.id),
                            "old_plan": old_plan,
                            "new_plan": new_plan,
                            "stripe_updated": stripe_updated,
                            "preserved_trial": is_trialing,
                            "user_id": current_user["user_id"]
                        })

        # Step 4: Invalidate subscription cache
        redis_client = None
        try:
            from app.services.subscription_cache import get_subscription_cache_service
            # Reuse the module's lazily-initialised client instead of calling
            # initialize_redis again on every upgrade. It yields None when
            # Redis is unavailable.
            redis_client = await get_subscription_redis_client()
            cache_service = get_subscription_cache_service(redis_client)
            await cache_service.invalidate_subscription_cache(tenant_id)
            logger.info("Subscription cache invalidated after upgrade",
                        extra={"tenant_id": tenant_id, "new_plan": new_plan})
        except Exception as cache_error:
            logger.error("Failed to invalidate subscription cache after upgrade",
                         extra={"tenant_id": tenant_id, "error": str(cache_error)})

        # Step 5: Invalidate tokens for immediate UI refresh
        try:
            if redis_client:
                await _invalidate_tenant_tokens(tenant_id, redis_client)
                logger.info("Invalidated all tokens for tenant after subscription upgrade",
                            extra={"tenant_id": tenant_id})
        except Exception as token_error:
            logger.error("Failed to invalidate tenant tokens after upgrade",
                         extra={"tenant_id": tenant_id, "error": str(token_error)})

        # Step 6: Publish subscription change event for other services
        try:
            from shared.messaging import UnifiedEventPublisher
            event_publisher = UnifiedEventPublisher()
            await event_publisher.publish_business_event(
                event_type="subscription.changed",
                tenant_id=tenant_id,
                data={
                    "tenant_id": tenant_id,
                    "old_tier": old_plan,
                    "new_tier": new_plan,
                    "action": "upgrade",
                    "is_trialing": is_trialing,
                    "trial_ends_at": trial_ends_at.isoformat() if trial_ends_at else None,
                    "stripe_updated": stripe_updated
                }
            )
            logger.info("Published subscription change event",
                        extra={"tenant_id": tenant_id, "event_type": "subscription.changed"})
        except Exception as event_error:
            logger.error("Failed to publish subscription change event",
                         extra={"tenant_id": tenant_id, "error": str(event_error)})

        return {
            "success": True,
            "message": f"Plan successfully upgraded to {new_plan}" + (" (trial preserved)" if is_trialing else ""),
            "old_plan": old_plan,
            "new_plan": new_plan,
            "new_monthly_price": updated_subscription.monthly_price,
            "is_trialing": is_trialing,
            "trial_ends_at": trial_ends_at.isoformat() if trial_ends_at else None,
            "stripe_updated": stripe_updated,
            "validation": validation,
            "requires_token_refresh": True
        }
    except HTTPException:
        # 400/404 raised above already carry the right status code.
        raise
    except Exception as e:
        # Keep the traceback: this handler covers every step of the upgrade,
        # so the message alone rarely identifies what failed.
        logger.error("Failed to upgrade subscription plan",
                     extra={"tenant_id": tenant_id, "new_plan": new_plan, "error": str(e)},
                     exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to upgrade subscription plan"
        )
# ============================================================================
# QUOTA & LIMIT CHECK ENDPOINTS
# ============================================================================
@router.get("/tenants/{tenant_id}/subscription/limits/locations",
            response_model=Dict[str, Any],
            summary="Check location limits")
async def check_location_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """
    Check if tenant can add another location.

    Delegates to ``limit_service.can_add_location`` and returns its result
    unchanged.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        limit_service: Service that performs the limit check.

    Returns:
        Whatever ``can_add_location`` returns for this tenant.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await limit_service.can_add_location(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to check location limits",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check location limits"
        ) from e
@router.get("/tenants/{tenant_id}/subscription/limits/products",
            response_model=Dict[str, Any],
            summary="Check product limits")
async def check_product_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """
    Check if tenant can add another product.

    Delegates to ``limit_service.can_add_product`` and returns its result
    unchanged.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        limit_service: Service that performs the limit check.

    Returns:
        Whatever ``can_add_product`` returns for this tenant.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await limit_service.can_add_product(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to check product limits",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check product limits"
        ) from e
@router.get("/tenants/{tenant_id}/subscription/limits/users",
            response_model=Dict[str, Any],
            summary="Check user limits")
async def check_user_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """
    Check if tenant can add another user/member.

    Delegates to ``limit_service.can_add_user`` and returns its result
    unchanged.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        limit_service: Service that performs the limit check.

    Returns:
        Whatever ``can_add_user`` returns for this tenant.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await limit_service.can_add_user(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to check user limits",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check user limits"
        ) from e
@router.get("/tenants/{tenant_id}/subscription/limits/recipes",
            response_model=Dict[str, Any],
            summary="Check recipe limits")
async def check_recipe_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """
    Check if tenant can add another recipe.

    Delegates to ``limit_service.can_add_recipe`` and returns its result
    unchanged.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        limit_service: Service that performs the limit check.

    Returns:
        Whatever ``can_add_recipe`` returns for this tenant.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await limit_service.can_add_recipe(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to check recipe limits",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check recipe limits"
        ) from e
@router.get("/tenants/{tenant_id}/subscription/limits/suppliers",
            response_model=Dict[str, Any],
            summary="Check supplier limits")
async def check_supplier_limits(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service)
) -> Dict[str, Any]:
    """
    Check if tenant can add another supplier.

    Delegates to ``limit_service.can_add_supplier`` and returns its result
    unchanged.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        limit_service: Service that performs the limit check.

    Returns:
        Whatever ``can_add_supplier`` returns for this tenant.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await limit_service.can_add_supplier(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to check supplier limits",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to check supplier limits"
        ) from e
# ============================================================================
# PAYMENT MANAGEMENT ENDPOINTS
# ============================================================================
@router.get("/tenants/{tenant_id}/subscription/payment-method",
            response_model=Dict[str, Any],
            summary="Get payment method")
async def get_payment_method(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Get current payment method for a subscription.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        orchestration_service: Service queried for the payment method.

    Returns:
        The value returned by ``orchestration_service.get_payment_method``.
        When that value is ``None``, a placeholder dict with ``brand``,
        ``last4``, ``exp_month`` and ``exp_year`` all set to ``None``.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        result = await orchestration_service.get_payment_method(tenant_id)
        # Ensure we always return a proper response structure
        if result is None:
            return {
                "brand": None,
                "last4": None,
                "exp_month": None,
                "exp_year": None
            }
        return result
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to get payment method",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get payment method"
        ) from e
@router.post("/tenants/{tenant_id}/subscription/payment-method",
             response_model=Dict[str, Any],
             summary="Update payment method")
async def update_payment_method(
    tenant_id: str = Path(..., description="Tenant ID"),
    payment_method_id: str = Query(..., description="New payment method ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Update the default payment method for a subscription.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        payment_method_id: New payment method ID, supplied as a required
            query parameter.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        orchestration_service: Service that performs the update.

    Returns:
        Whatever ``orchestration_service.update_payment_method`` returns.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await orchestration_service.update_payment_method(tenant_id, payment_method_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to update payment method",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update payment method"
        ) from e
@router.get("/tenants/{tenant_id}/subscription/invoices",
            response_model=Dict[str, Any],
            summary="Get invoices")
async def get_invoices(
    tenant_id: str = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Get invoice history for a tenant.

    Args:
        tenant_id: Tenant ID taken from the URL path.
        current_user: Value produced by ``get_current_user_dep``; it is not
            read in the body, but declaring it makes the dependency run.
        orchestration_service: Service queried for the invoices.

    Returns:
        Whatever ``orchestration_service.get_invoices`` returns.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged and reported as a 500.
    """
    try:
        return await orchestration_service.get_invoices(tenant_id)
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to get invoices",
                     extra={"tenant_id": tenant_id, "error": str(e)})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get invoices"
        ) from e
# ============================================================================
# SETUP INTENT VERIFICATION
# ============================================================================
@router.get("/setup-intents/{setup_intent_id}/verify",
            response_model=Dict[str, Any],
            summary="Verify SetupIntent status")
async def verify_setup_intent(
    setup_intent_id: str = Path(..., description="SetupIntent ID"),
    request: Request = None,
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Verify SetupIntent status for registration flow

    Calls ``orchestration_service.verify_setup_intent_for_registration`` and
    repackages selected keys of its result into a flat response.

    Args:
        setup_intent_id: SetupIntent ID taken from the URL path.
        request: Incoming request object; not read in the body.
        orchestration_service: Service that performs the verification.

    Returns:
        Dict with ``success`` (always ``True`` on this path),
        ``setup_intent_id``, and the ``status``, ``payment_method_id``,
        ``customer_id``, ``requires_action`` (default ``False``),
        ``action_type``, ``client_secret`` and ``message`` values read from
        the verification result.

    Raises:
        HTTPException: Propagated untouched when raised by the service;
            any other exception is logged with its traceback and reported
            as a 500 whose detail includes the exception text.
    """
    try:
        logger.info("Verifying SetupIntent for registration",
                    extra={"setup_intent_id": setup_intent_id})
        result = await orchestration_service.verify_setup_intent_for_registration(setup_intent_id)
        # Local name avoids shadowing the `status` module used for HTTP codes below
        intent_status = result.get('status')
        logger.info("SetupIntent verification completed",
                    extra={
                        "setup_intent_id": setup_intent_id,
                        "status": intent_status
                    })
        return {
            "success": True,
            "setup_intent_id": setup_intent_id,
            "status": intent_status,
            "payment_method_id": result.get('payment_method_id'),
            "customer_id": result.get('customer_id'),
            "requires_action": result.get('requires_action', False),
            "action_type": result.get('action_type'),
            "client_secret": result.get('client_secret'),
            "message": result.get('message', 'SetupIntent verification completed successfully')
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("SetupIntent verification failed",
                     extra={"error": str(e), "setup_intent_id": setup_intent_id},
                     exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"SetupIntent verification failed: {str(e)}"
        ) from e
# ============================================================================
# PAYMENT CUSTOMER CREATION
# ============================================================================
@router.post("/payment-customers/create",
             response_model=Dict[str, Any],
             summary="Create payment customer")
async def create_payment_customer(
    user_data: Dict[str, Any],
    payment_method_id: Optional[str] = Query(None, description="Optional payment method ID"),
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
) -> Dict[str, Any]:
    """
    Create a payment customer in the payment provider
    This endpoint is designed for service-to-service communication from auth service
    during user registration.

    Args:
        user_data: Request body forwarded to ``PaymentService.create_customer``;
            its ``email`` and ``user_id`` keys are read for logging only.
        payment_method_id: Optional payment method to attach to the new
            customer. Attaching is best-effort: a failure is logged as a
            warning and the response carries ``payment_method: None``.
        orchestration_service: Injected dependency; not read in the body.

    Returns:
        Dict with ``success``, ``payment_customer_id``, ``payment_method``
        (details dict or ``None``) and a ``customer`` summary (``id``,
        ``email``, ``name``, ISO-formatted ``created_at``).

    Raises:
        HTTPException: Propagated untouched when raised downstream; any
            other exception is logged and reported as a 500 whose detail
            includes the exception text.
    """
    try:
        from app.services.payment_service import PaymentService
        payment_service = PaymentService()
        logger.info("Creating payment customer via service-to-service call",
                    extra={"email": user_data.get('email'), "user_id": user_data.get('user_id')})
        customer = await payment_service.create_customer(user_data)
        logger.info("Payment customer created successfully",
                    extra={"customer_id": customer.id, "email": customer.email})
        payment_method_details = None
        if payment_method_id:
            try:
                payment_method = await payment_service.update_payment_method(
                    customer.id,
                    payment_method_id
                )
                payment_method_details = {
                    "id": payment_method.id,
                    "type": payment_method.type,
                    "brand": payment_method.brand,
                    "last4": payment_method.last4,
                    "exp_month": payment_method.exp_month,
                    "exp_year": payment_method.exp_year
                }
                logger.info("Payment method attached to customer",
                            extra={"customer_id": customer.id, "payment_method_id": payment_method.id})
            except Exception as attach_error:
                # Deliberately best-effort: the customer already exists, so an
                # attach failure must not fail the whole request
                logger.warning("Failed to attach payment method to customer",
                               extra={"customer_id": customer.id, "error": str(attach_error), "payment_method_id": payment_method_id})
        return {
            "success": True,
            "payment_customer_id": customer.id,
            "payment_method": payment_method_details,
            "customer": {
                "id": customer.id,
                "email": customer.email,
                "name": customer.name,
                "created_at": customer.created_at.isoformat()
            }
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as a 500
        raise
    except Exception as e:
        logger.error("Failed to create payment customer via service-to-service call",
                     extra={"error": str(e), "email": user_data.get('email'), "user_id": user_data.get('user_id')})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create payment customer: {str(e)}"
        ) from e
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
async def _invalidate_tenant_tokens(tenant_id: str, redis_client):
    """
    Record a "subscription changed" marker for a tenant in Redis.

    Writes the current UTC epoch time (as a string) under the key
    ``tenant:{tenant_id}:subscription_changed_at`` with a 24 hour expiry.
    The consumer of this marker is not visible here; presumably token
    validation compares a token's issue time against it to force
    re-authentication with fresh subscription data.

    Args:
        tenant_id: Tenant whose marker is written.
        redis_client: Client exposing an awaitable ``set(key, value, ex=...)``.

    Raises:
        Exception: Any failure is logged and then re-raised to the caller.
    """
    try:
        marker_key = f"tenant:{tenant_id}:subscription_changed_at"
        ttl_seconds = 86400  # 24 hour TTL
        now_epoch = datetime.now(timezone.utc).timestamp()
        await redis_client.set(marker_key, str(now_epoch), ex=ttl_seconds)
        logger.info(
            "Set subscription change timestamp for token invalidation",
            extra={"tenant_id": tenant_id, "timestamp": now_epoch},
        )
    except Exception as exc:
        logger.error(
            "Failed to invalidate tenant tokens",
            extra={"tenant_id": tenant_id, "error": str(exc)},
        )
        raise